Assignment 4¶

In [2]:
# Importing Dependencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import seaborn as sns
sns.set_style('whitegrid')

import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as Fun
import torch.optim as optim

Question 1 [1 Marks]¶

Implement Logistic Regression using the Pyro library referring [1] for guidance.

Show both the mean prediction as well as standard deviation in the predictions over the 2d grid. Use NUTS MCMC sampling to sample the posterior. Take 1000 samples for posterior distribution and use 500 samples as burn/warm up. Use the below given dataset.

In [3]:
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

X, y = make_moons(n_samples=100, noise=0.3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2, random_state=41)
# X_train = torch.tensor(X_train); y_train = torch.tensor(y_train)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Dataset')
plt.show()
In [4]:
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Defining the logistic regression pyro model

def logistic_regression(X, y=None):
    # The weights and bias are random variables, we assume the weights to be normally distributed
    weight = pyro.sample("weight", dist.Normal(0, 1).expand([X.shape[1]]))
    bias = pyro.sample("bias", dist.Normal(0, 1))

    # Compute the logits
    logits = torch.matmul(X, weight) + bias

    # Model observation (likelihood)
    # Repeated block of code repeated X.shape[0] times named "data"
    with pyro.plate("data", X.shape[0]):
        obs = pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
In [5]:
# Running the model on the data

model = logistic_regression; num_samples = 1000; burn_in = 500
posterior = MCMC(NUTS(model), num_samples=num_samples, warmup_steps=burn_in)
posterior.run(X_train, y_train)

weights_samples = posterior.get_samples()["weight"]
bias_samples = posterior.get_samples()["bias"]

weights_samples.shape, bias_samples.shape
Sample: 100%|██████████| 1500/1500 [00:08, 186.67it/s, step size=7.36e-01, acc. prob=0.886]
Out[5]:
(torch.Size([1000, 2]), torch.Size([1000]))
In [6]:
# Define the grid over which you want to make predictions
grid_x = np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100)
grid_y = np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100)
grid_xx, grid_yy = np.meshgrid(grid_x, grid_y)
grid = np.column_stack([grid_xx.ravel(), grid_yy.ravel()])

grid.shape
Out[6]:
(10000, 2)
In [7]:
# Calculate the logits for each sample
logits_samples = torch.matmul(torch.tensor(grid, dtype=torch.float32), weights_samples.t()) + bias_samples
prob_samples = torch.sigmoid(logits_samples)
mean_predictions, std_predictions = prob_samples.mean(1), prob_samples.std(1)

prob_samples.shape, mean_predictions.shape, std_predictions.shape
Out[7]:
(torch.Size([10000, 1000]), torch.Size([10000]), torch.Size([10000]))
In [8]:
mean_predictions = mean_predictions.reshape(grid_xx.shape)
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

contour1 = ax[0].contourf(grid_xx, grid_yy, mean_predictions, levels=50, cmap="RdBu", alpha=0.6)
colorbar1 = plt.colorbar(contour1, ax=ax[0], label='Mean Prediction')

contour2 = ax[1].contourf(grid_xx, grid_yy, mean_predictions, levels=50, cmap="RdBu", alpha=0.6)
colorbar2 = plt.colorbar(contour2, ax=ax[1], label='Mean Prediction')

ax[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap="RdBu_r", edgecolor='k', label = 'Training Data')
ax[1].scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap="RdBu_r", edgecolor='k', label = 'Test Data')
ax[0].legend(); ax[1].legend()
fig.suptitle('Logistic Regression Mean Prediction')
Out[8]:
Text(0.5, 0.98, 'Logistic Regression Mean Prediction')
In [9]:
std_predictions = std_predictions.reshape(grid_xx.shape)
fig, ax = plt.subplots(1, 2, figsize=(12, 5))

contour1 = ax[0].contourf(grid_xx, grid_yy, std_predictions, levels=50, cmap="viridis", alpha=0.6)
colorbar1 = plt.colorbar(contour1, ax=ax[0], label='Std Prediction')

contour2 = ax[1].contourf(grid_xx, grid_yy, std_predictions, levels=50, cmap="viridis", alpha=0.6)
colorbar2 = plt.colorbar(contour2, ax=ax[1], label='Std Prediction')

ax[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap="RdBu_r", edgecolor='k', label = 'Training Data')
ax[1].scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap="RdBu_r", edgecolor='k', label = 'Test Data')
ax[0].legend(); ax[1].legend()
fig.suptitle('Logistic Regression Std Prediction')
Out[9]:
Text(0.5, 0.98, 'Logistic Regression Std Prediction')

Question 2 [2 Marks]¶

Consider the FVC dataset example discussed in the class.

We had only used the train dataset. Now, we want to find out the performance of various models on the test dataset.

Use the given dataset and deduce which model works best in terms of error (MAE) and coverage? The base model is Linear Regression by Sklearn (from sklearn.linear_model import LinearRegression). Plot the trace diagrams and posterior distribution.

Also plot the predictive posterior distribution with 90% confidence interval.

In [10]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
from jax import random

import warnings
warnings.filterwarnings('ignore')

plt.rcParams['figure.constrained_layout.use'] = True

import seaborn as sns
sns.set_context("notebook")

import numpyro
import numpyro.distributions as dist
In [11]:
import os
import requests

URL = "https://gist.githubusercontent.com/ucals/" + "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"+ "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"+ "osic_pulmonary_fibrosis.csv"

if not os.path.exists("osic_pulmonary_fibrosis.csv"):
    response = requests.get(URL)

    with open("osic_pulmonary_fibrosis.csv", "wb") as f:
        f.write(response.content)

train = pd.read_csv("osic_pulmonary_fibrosis.csv")
train.head()
Out[11]:
Patient Weeks FVC Percent Age Sex SmokingStatus
0 ID00007637202177411956430 -4 2315 58.253649 79 Male Ex-smoker
1 ID00007637202177411956430 5 2214 55.712129 79 Male Ex-smoker
2 ID00007637202177411956430 7 2061 51.862104 79 Male Ex-smoker
3 ID00007637202177411956430 9 2144 53.950679 79 Male Ex-smoker
4 ID00007637202177411956430 11 2069 52.063412 79 Male Ex-smoker
In [12]:
train.describe()
Out[12]:
Weeks FVC Percent Age
count 1549.000000 1549.000000 1549.000000 1549.000000
mean 31.861846 2690.479019 77.672654 67.188509
std 23.247550 832.770959 19.823261 7.057395
min -5.000000 827.000000 28.877577 49.000000
25% 12.000000 2109.000000 62.832700 63.000000
50% 28.000000 2641.000000 75.676937 68.000000
75% 47.000000 3171.000000 88.621065 72.000000
max 133.000000 6399.000000 153.145378 88.000000
In [13]:
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)

X = train.drop(columns=["Patient", "FVC", "Percent", "Age", "Sex", "SmokingStatus"])
y = train["FVC"]
x_train, x_test, y_train, y_test = train_test_split(X, y, train_size = 0.8, random_state = 0)
In [14]:
x_train
Out[14]:
Weeks patient_code
295 63 33
516 63 59
655 46 74
838 48 94
452 61 51
... ... ...
763 27 86
835 18 94
1216 38 138
559 73 63
684 18 77

1239 rows × 2 columns

In [15]:
len(x_train), len(x_test), len(y_train), len(y_test)
Out[15]:
(1239, 310, 1239, 310)
In [16]:
sample_patient_code_train = x_train["patient_code"].values
sample_patient_code_test = x_test["patient_code"].values
x_train = x_train["Weeks"]
x_test = x_test["Weeks"]

# Converting into numpy arrays
x_train = np.array(x_train); x_test = np.array(x_test)
y_train = np.array(y_train); y_test = np.array(y_test)
In [17]:
x_train, y_train
Out[17]:
(array([63, 63, 46, ..., 38, 73, 18]),
 array([2957, 3327, 2205, ..., 3882, 3907, 3054]))

Vanilla Linear Regression¶

In [18]:
### Linear regression from scikit-learn
from sklearn.linear_model import LinearRegression

lr = LinearRegression()
lr.fit(x_train.reshape(-1, 1), y_train)
Out[18]:
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
In [19]:
lr_sigma = np.std(y_train - lr.predict(x_train.reshape(-1, 1)))
lr.coef_, lr.intercept_, lr_sigma
Out[19]:
(array([-1.68051699]), 2747.966054074546, 843.8124655154415)
In [20]:
# Plot the data and the regression line
all_weeks = np.arange(-12, 134, 1)

plt.scatter(x_train, y_train, alpha=0.3, label = "Train data")
plt.scatter(x_test, y_test, color="red", alpha=0.3, label = "Test data")
plt.plot(all_weeks, lr.predict(all_weeks.reshape(-1, 1)), color="black", lw=2, label = "Linear Regression")
plt.fill_between(all_weeks, lr.predict(all_weeks.reshape(-1, 1)) - 1.96*lr_sigma, 
                 lr.predict(all_weeks.reshape(-1, 1)) + 1.96*lr_sigma, alpha=0.2, label = "Std fill")
plt.title("Linear Regression fit")
plt.legend()
plt.xlabel("Weeks")
plt.ylabel("FVC")
Out[20]:
Text(0, 0.5, 'FVC')
In [21]:
# Finding the mean absolute error on train set
from sklearn.metrics import mean_absolute_error

maes = {}
maes["LinearRegression on train"] = mean_absolute_error(y_train, lr.predict(x_train.reshape(-1, 1)))
maes["LinearRegression on test"] = mean_absolute_error(y_test, lr.predict(x_test.reshape(-1, 1)))
maes
Out[21]:
{'LinearRegression on train': 662.1236659544445,
 'LinearRegression on test': 626.3184730275215}
In [22]:
# Finding the 95% coverage on train set

def coverage(y_true, y_pred, sigma):
    lower = y_pred - 1.96 * sigma
    upper = y_pred + 1.96 * sigma
    return np.mean((y_true >= lower) & (y_true <= upper))

coverages = {}
print("Train Coverage: ", coverage(y_train, lr.predict(x_train.reshape(-1, 1)), lr_sigma))
coverages["LinearRegression on test"] = coverage(y_test, lr.predict(x_test.reshape(-1, 1)), lr_sigma)
coverages["LinearRegression on train"] = coverage(y_train, lr.predict(x_train.reshape(-1, 1)), lr_sigma)
coverages
Train Coverage:  0.9548022598870056
Out[22]:
{'LinearRegression on test': 0.9709677419354839,
 'LinearRegression on train': 0.9548022598870056}

Pooled model¶

$\alpha \sim \text{Normal}(0, 500)$

$\beta \sim \text{Normal}(0, 500)$

$\sigma \sim \text{HalfNormal}(100)$

for i in range(N_Weeks):

$FVC_i \sim \text{Normal}(\alpha + \beta \cdot Week_i, \sigma)$

In [23]:
def pooled_model(sample_weeks, sample_fvc=None):
    α = numpyro.sample("α", dist.Normal(0., 500.))
    β = numpyro.sample("β", dist.Normal(0., 500.))
    σ = numpyro.sample("σ", dist.HalfNormal(50.))
    with numpyro.plate("samples", len(sample_weeks)):
        fvc = numpyro.sample("fvc", dist.Normal(α + β * sample_weeks, σ), obs=sample_fvc)
    return fvc

sample_weeks = train["Weeks"].values
sample_fvc = train["FVC"].values
In [24]:
from numpyro.infer import MCMC, NUTS, Predictive
nuts_kernel = NUTS(pooled_model)

mcmc = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, sample_weeks=x_train, sample_fvc=y_train)
posterior_samples = mcmc.get_samples()
sample: 100%|██████████| 6000/6000 [00:02<00:00, 2206.33it/s, 7 steps of size 4.57e-01. acc. prob=0.91]
In [25]:
import arviz as az

idata = az.from_numpyro(mcmc)
az.plot_trace(idata, compact=True);
In [26]:
# Summary statistics
az.summary(idata, round_to=2)
arviz - WARNING - Shape validation failed: input_shape: (1, 4000), minimum_shape: (chains=2, draws=4)
Out[26]:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
α 2731.78 37.24 2658.75 2798.00 0.90 0.64 1713.98 1879.14 NaN
β -1.36 0.93 -3.01 0.44 0.02 0.02 1649.09 1842.16 NaN
σ 773.41 12.98 750.62 799.75 0.25 0.18 2697.11 2607.93 NaN

Making predictions

In [27]:
predictive = Predictive(pooled_model, mcmc.get_samples())
predictions = predictive(rng_key, all_weeks, None)
print(predictions["fvc"].shape)

pd.DataFrame(predictions["fvc"]).mean().plot(label = "Pooled Model")
plt.plot(all_weeks, lr.predict(all_weeks.reshape(-1, 1)), color="black", lw=2, label = "Linear Regression")
plt.title("Pooled Model Predictions")
plt.xlabel("Weeks"); plt.ylabel("FVC")
plt.legend()
(4000, 146)
Out[27]:
<matplotlib.legend.Legend at 0x153679e50>
In [28]:
# Predictive distribution
predictive = Predictive(pooled_model, mcmc.get_samples())
predictions = predictive(rng_key, all_weeks, None)

# Get the mean and standard deviation of the predictions
mu = predictions["fvc"].mean(axis=0)
sigma = predictions["fvc"].std(axis=0)

# Plot the predictions
plt.plot(all_weeks, mu)
plt.fill_between(all_weeks, mu - 1.96*sigma, mu + 1.96*sigma, alpha=0.2, label = "Std fill")
plt.scatter(sample_weeks, sample_fvc, alpha=0.2, label = "Train data")
plt.scatter(x_test, y_test, color="red", alpha=0.3, label = "Test data")
plt.xlabel("Weeks")
plt.title("Predictive distribution of Pooled Model")
plt.ylabel("FVC")
plt.legend()
Out[28]:
<matplotlib.legend.Legend at 0x1546d2290>
In [29]:
### Computing Mean Absolute Error and Coverage at 95% confidence interval

preds_pooled  = predictive(rng_key, x_train, None)['fvc']
predictions_train_pooled = preds_pooled.mean(axis=0)
std_train_pooled = preds_pooled.std(axis=0)

pred_test_pooled = predictive(rng_key, x_test, None)['fvc']
predictions_test_pooled = pred_test_pooled.mean(axis=0)
std_test_pooled = pred_test_pooled.std(axis=0)

maes["PooledModel on train"] = mean_absolute_error(y_train, predictions_train_pooled)
maes["PooledModel on test"] = mean_absolute_error(y_test, predictions_test_pooled)
maes
Out[29]:
{'LinearRegression on train': 662.1236659544445,
 'LinearRegression on test': 626.3184730275215,
 'PooledModel on train': 661.24235853637,
 'PooledModel on test': 626.4491242439516}
In [30]:
### Computing the coverage at 95% confidence interval

coverages["PooledModel on test"] = coverage(y_test, predictions_test_pooled, std_test_pooled).item()
coverages["PooledModel on train"] = coverage(y_train, predictions_train_pooled, std_train_pooled).item()
coverages
Out[30]:
{'LinearRegression on test': 0.9709677419354839,
 'LinearRegression on train': 0.9548022598870056,
 'PooledModel on test': 0.948387086391449,
 'PooledModel on train': 0.938660204410553}

Partially pooled model with the same sigma¶

In [31]:
### Hierarchical model

def paritally_pooled_model(sample_weeks, sample_patient_code, sample_fvc=None):
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))

    n_patients = len(np.unique(sample_patient_code))

    with numpyro.plate("Participants", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[sample_patient_code] + β[sample_patient_code] * sample_weeks

    with numpyro.plate("data", len(sample_patient_code)):
        numpyro.sample("fvc", dist.Normal(FVC_est, σ), obs=sample_fvc)
In [32]:
model_kwargs_train = {"sample_weeks": x_train, "sample_patient_code": sample_patient_code_train, "sample_fvc": y_train}
model_kwargs_test = {"sample_weeks": x_test, "sample_patient_code": sample_patient_code_test, "sample_fvc": y_test}
In [33]:
nuts_final = NUTS(paritally_pooled_model)

mcmc_final = MCMC(nuts_final, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)

mcmc_final.run(rng_key, **model_kwargs_train)
sample: 100%|██████████| 6000/6000 [00:15<00:00, 388.28it/s, 63 steps of size 1.73e-02. acc. prob=0.85] 
In [34]:
predictive_final = Predictive(paritally_pooled_model, mcmc_final.get_samples())
az.plot_trace(az.from_numpyro(mcmc_final), compact=True);

Getting Mse and Coverage

In [35]:
predictive_hierarchical = Predictive(paritally_pooled_model, mcmc_final.get_samples())

predictions_train_hierarchical = predictive_final(rng_key,
                                sample_weeks = model_kwargs_train["sample_weeks"],
                                sample_patient_code = model_kwargs_train["sample_patient_code"])['fvc']

mu_predictions_train_h = predictions_train_hierarchical.mean(axis=0)
std_predictions_train_h = predictions_train_hierarchical.std(axis=0)

maes["Hierarchical on train"] = mean_absolute_error(y_train, mu_predictions_train_h)

coverages["Hierarchical on train"] = coverage(y_train, mu_predictions_train_h, std_predictions_train_h).item()

predictions_test_hierarchical = predictive_final(rng_key,
                                sample_weeks = model_kwargs_test["sample_weeks"],
                                sample_patient_code = model_kwargs_test["sample_patient_code"])['fvc']

mu_predictions_test_h = predictions_test_hierarchical.mean(axis=0)
std_predictions_test_h = predictions_test_hierarchical.std(axis=0)

maes["Hierarchical on test"] = mean_absolute_error(y_test, mu_predictions_test_h)

coverages["Hierarchical on test"] = coverage(y_test, mu_predictions_test_h, std_predictions_test_h).item()
In [36]:
maes
Out[36]:
{'LinearRegression on train': 662.1236659544445,
 'LinearRegression on test': 626.3184730275215,
 'PooledModel on train': 661.24235853637,
 'PooledModel on test': 626.4491242439516,
 'Hierarchical on train': 80.57488639283507,
 'Hierarchical on test': 110.49720419606855}
In [37]:
coverages
Out[37]:
{'LinearRegression on test': 0.9709677419354839,
 'LinearRegression on train': 0.9548022598870056,
 'PooledModel on test': 0.948387086391449,
 'PooledModel on train': 0.938660204410553,
 'Hierarchical on train': 0.9741727113723755,
 'Hierarchical on test': 0.9387096762657166}
In [38]:
# Predict for a given patient

def predict_final(patient_code):
    predictions = predictive_final(rng_key, all_weeks, patient_code)
    mu = predictions["fvc"].mean(axis=0)
    sigma = predictions["fvc"].std(axis=0)
    return mu, sigma

# Plot the predictions for a given patient
def plot_patient_final(patient_code):

    mu, sigma = predict_final(patient_code)
    plt.plot(all_weeks, mu)
    plt.fill_between(all_weeks, mu - sigma, mu + sigma, alpha=0.1)
    id_to_patient = patient_encoder.inverse_transform([patient_code])[0]

    patient_weeks = train[train["Patient"] == id_to_patient]["Weeks"]
    patient_fvc = train[train["Patient"] == id_to_patient]["FVC"]
    plt.scatter(patient_weeks, patient_fvc, alpha=0.5)
    plt.xlabel("Weeks")
    plt.ylabel("FVC")
    plt.title(patient_encoder.inverse_transform([patient_code])[0])
In [39]:
# plot for a given patient
plot_patient_final(np.array([0]))

Partially pooled model with the sigma hyperpriors¶

In [40]:
### Hierarchical model

def Partially_pooled_sigma_model(sample_weeks, sample_patient_code, sample_fvc=None):
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))
    𝛄_σ = numpyro.sample("𝛄_σ", dist.HalfNormal(30.0))
    n_patients = len(np.unique(sample_patient_code))

    with numpyro.plate("Participants", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))
        σ = numpyro.sample("σ", dist.Exponential(𝛄_σ))
    FVC_est = α[sample_patient_code] + β[sample_patient_code] * sample_weeks

    with numpyro.plate("data", len(sample_patient_code)):
        numpyro.sample("fvc", dist.Normal(FVC_est, σ[sample_patient_code]), obs=sample_fvc)
In [41]:
nuts_kernel = NUTS(Partially_pooled_sigma_model)

mcmc_3 = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)

mcmc_3.run(rng_key, **model_kwargs_train)
sample: 100%|██████████| 6000/6000 [00:26<00:00, 224.98it/s, 63 steps of size 5.69e-02. acc. prob=0.87] 
In [42]:
az.plot_trace(az.from_numpyro(mcmc_3), compact=True);
In [43]:
predictive = Predictive(Partially_pooled_sigma_model, mcmc_3.get_samples())

predictive_train_3 = predictive(rng_key,sample_weeks = model_kwargs_train["sample_weeks"],
                                sample_patient_code = model_kwargs_train["sample_patient_code"])['fvc']

mu_predictions_train_3 = predictive_train_3.mean(axis=0)
std_predictions_train_3 = predictive_train_3.std(axis=0)

maes["Hierarchial sigma train"] = mean_absolute_error(y_train, mu_predictions_train_3)
coverages["Hierarchial sigma train"] = coverage(y_train, mu_predictions_train_3, std_predictions_train_3).item()

predictive_test_3 = predictive(rng_key,sample_weeks = model_kwargs_test["sample_weeks"],
                               sample_patient_code = model_kwargs_test["sample_patient_code"])['fvc']

mu_predictions_test_3 = predictive_test_3.mean(axis=0)
std_predictions_test_3 = predictive_test_3.std(axis=0)

maes["Hierarchial sigma test"] = mean_absolute_error(y_test, mu_predictions_test_3)
coverages["Hierarchial sigma test"] = coverage(y_test, mu_predictions_test_3, std_predictions_test_3).item()
In [44]:
# Predict for a given patient
def predict_final_3(patient_code):
    predictions = predictive(rng_key, all_weeks, patient_code)
    mu = predictions["fvc"].mean(axis=0)
    sigma = predictions["fvc"].std(axis=0)
    return mu, sigma

# Plot the predictions for a given patient
def plot_patient_final_3(patient_code):
    mu, sigma = predict_final_3(patient_code)
    plt.plot(all_weeks, mu)
    plt.fill_between(all_weeks, mu - sigma, mu + sigma, alpha=0.1)
    id_to_patient = patient_encoder.inverse_transform([patient_code])[0]
    #print(id_to_patient[0], patient_code)
    #print(patient_code, id_to_patient)
    patient_weeks = train[train["Patient"] == id_to_patient]["Weeks"]
    patient_fvc = train[train["Patient"] == id_to_patient]["FVC"]
    plt.scatter(patient_weeks, patient_fvc, alpha=0.5)
    #plt.scatter(sample_weeks[train["patient_code"] == patient_code.item()], fvc[train["patient_code"] == patient_code.item()], alpha=0.5)
    plt.xlabel("Weeks")
    plt.ylabel("FVC")
    plt.title(patient_encoder.inverse_transform([patient_code])[0])
In [45]:
# plot for a given patient
plot_patient_final_3(np.array([32]))
In [48]:
pd.Series(maes)
Out[48]:
LinearRegression on train    662.123666
LinearRegression on test     626.318473
PooledModel on train         661.242359
PooledModel on test          626.449124
Hierarchical on train         80.574886
Hierarchical on test         110.497204
Hierarchial sigma train       85.884296
Hierarchial sigma test       111.439091
dtype: float64
In [49]:
pd.Series(coverages)
Out[49]:
LinearRegression on test     0.970968
LinearRegression on train    0.954802
PooledModel on test          0.948387
PooledModel on train         0.938660
Hierarchical on train        0.974173
Hierarchical on test         0.938710
Hierarchial sigma train      0.996772
Hierarchial sigma test       0.948387
dtype: float64

Question 3 [4 Marks]¶

Use your version of following models to reproduce figure 4 from the paper referenced at [2].

You can also refer to the notebook in the course.

1) Hypernet [2 marks] 2) Neural Processes [2 marks]


Hypernet¶

In [ ]:
import torch
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retina display
%config InlineBackend.figure_format = 'retina'
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [ ]:
import os

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

root = "/content/drive/MyDrive/Dataset/CelebA_Dataset"
celeba_dataset = []

for filename in os.listdir(root):
    if filename.endswith('.jpg'):
        image_path = os.path.join(root, filename)
        image = datasets.folder.default_loader(image_path)
        image = transform(image)
        celeba_dataset.append(image)
In [ ]:
len(celeba_dataset)
Out[ ]:
761
In [ ]:
img0 = transform(datasets.folder.default_loader("/content/drive/MyDrive/Dataset/CelebA_Dataset/000001.jpg"))
print(img0.shape)
plt.imshow(img0.permute(1,2,0))
torch.Size([3, 64, 64])
Out[ ]:
<matplotlib.image.AxesImage at 0x7c8a5c02a1d0>
In [ ]:
from sklearn import preprocessing

def create_scaled_cmap(img, rt = False):
    """
    Creates a scaled image and a scaled colormap
    """
    img= img
    num_channels, height, width = img.shape

    # Create a 2D grid of (x,y) coordinates
    x_coords = torch.arange(width).repeat(height, 1)
    y_coords = torch.arange(height).repeat(width, 1).t()
    x_coords = x_coords.reshape(-1)
    y_coords = y_coords.reshape(-1)

    X = torch.stack([x_coords, y_coords], dim=1).float().to(device)
    # Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    if rt == True:
      Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    else:
      Y = img.reshape(-1, num_channels).float().to(device)

    scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(X.cpu())
    scaled_X = torch.tensor(scaler_X.transform(X.cpu())).to(device).float()

    return scaled_X, Y, scaler_X
In [ ]:
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0)
img0_X.shape, img0_Y.shape
Out[ ]:
(torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
s = 64

class NN(nn.Module):
    def _init_siren(self, activation_scale):
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/activation_scale,
                                        np.sqrt(6/self.fc2.in_features)/activation_scale)

    def __init__(self, activation=torch.sin, n_out=3, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(2, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, 3) #gray scale image (1) or RGB (3)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        # x = self.activation(self.activation_scale*self.fc4(x))
        return self.fc5(x)

Trying only the normal neural net Relu and Sin activations¶

This will act as our Target net

In [ ]:
# Making singular data for only 1 image
torch.manual_seed(0)
sh_index = torch.randperm(img0_X.shape[0])

# Shuffle the dataset
img0_X_sh = img0_X[sh_index]
img0_Y_sh = img0_Y[sh_index]
sh_index[0:10]
Out[ ]:
tensor([2732, 1810, 3111, 2738,  155, 2864, 2423, 2918, 2441, 3201])
In [ ]:
torch.manual_seed(0)

nns = {}
nns["img0"] = {}
nns["img0"]["relu"] = NN(activation=torch.relu, n_out=3).to(device)
nns["img0"]["sin"] = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)
In [ ]:
nns["img0"]["relu"](img0_X_sh).shape, nns["img0"]["sin"](img0_X_sh).shape
Out[ ]:
(torch.Size([4096, 3]), torch.Size([4096, 3]))
In [ ]:
# Training the network to recreate the image
def train_normalnet(net, lr, X, Y, epochs, verbose=True):
    """
    net: torch.nn.Module
    lr: float
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, 3)
    """

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()
In [ ]:
import time
n_iter = 2000

start_time = time.time()

train_normalnet(nns["img0"]["relu"], lr=3e-4, X=img0_X_sh, Y=img0_Y_sh, epochs=n_iter)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.330164
Epoch 100 loss: 0.072692
Epoch 200 loss: 0.071209
Epoch 300 loss: 0.070467
Epoch 400 loss: 0.069878
Epoch 500 loss: 0.069198
Epoch 600 loss: 0.068229
Epoch 700 loss: 0.066728
Epoch 800 loss: 0.064489
Epoch 900 loss: 0.061017
Epoch 1000 loss: 0.055909
Epoch 1100 loss: 0.050226
Epoch 1200 loss: 0.045593
Epoch 1300 loss: 0.042071
Epoch 1400 loss: 0.039665
Epoch 1500 loss: 0.037875
Epoch 1600 loss: 0.036552
Epoch 1700 loss: 0.035374
Epoch 1800 loss: 0.034569
Epoch 1900 loss: 0.033748
Training time: 3.97 seconds
In [ ]:
output = nns["img0"]["relu"](img0_X)#.detach().cpu().numpy()
print(output.shape)
num_channels, height, width = img0.shape

output = output.reshape(num_channels, height, width)
output = output.permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(4, 3))
ax.imshow(output.detach().cpu())
ax.set_title("Reconstructed Relu Image")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([4096, 3])
Out[ ]:
Text(0.5, 1.0, 'Reconstructed Relu Image')
In [ ]:
import time
n_iter = 2000

start_time = time.time()

train_normalnet(nns["img0"]["sin"], lr=3e-4, X=img0_X_sh, Y=img0_Y_sh, epochs=n_iter)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.393856
Epoch 100 loss: 0.012778
Epoch 200 loss: 0.007427
Epoch 300 loss: 0.005406
Epoch 400 loss: 0.004122
Epoch 500 loss: 0.003208
Epoch 600 loss: 0.002511
Epoch 700 loss: 0.001977
Epoch 800 loss: 0.001601
Epoch 900 loss: 0.001339
Epoch 1000 loss: 0.001124
Epoch 1100 loss: 0.000962
Epoch 1200 loss: 0.000838
Epoch 1300 loss: 0.000752
Epoch 1400 loss: 0.000663
Epoch 1500 loss: 0.000605
Epoch 1600 loss: 0.000568
Epoch 1700 loss: 0.000524
Epoch 1800 loss: 0.000518
Epoch 1900 loss: 0.000470
Training time: 3.48 seconds
In [ ]:
output = nns["img0"]["sin"](img0_X)#.detach().cpu().numpy()
print(output.shape)
num_channels, height, width = img0.shape

output = output.reshape(num_channels, height, width)
output = output.permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(4, 3))
ax.imshow(output.detach().cpu())
ax.set_title("Reconstructed Image Siren")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([4096, 3])
Out[ ]:
Text(0.5, 1.0, 'Reconstructed Image Siren')

Making a Hypernet and training it on only 1 image¶

In [ ]:
# Here we are looking at the composition of our target net, The hypernet should be able to
# produce those many features for these parameters

try:
    from tabulate import tabulate
except:
    %pip install tabulate
    from tabulate import tabulate

model = nns["img0"]["sin"]

table_data = []

total_params = 0
start = 0
start_end_mapping = {}
for name, param in model.named_parameters():
    param_count = torch.prod(torch.tensor(param.shape)).item()
    total_params += param_count
    end = total_params
    table_data.append([name, param.shape, param_count, start, end])
    start_end_mapping[name] = (start, end)
    start = end

print(tabulate(table_data, headers=["Layer Name", "Shape", "Parameter Count", "Start Index", "End Index"]))
print(f"Total number of parameters: {total_params}")
Layer Name    Shape                   Parameter Count    Start Index    End Index
------------  --------------------  -----------------  -------------  -----------
fc1.weight    torch.Size([64, 2])                 128              0          128
fc1.bias      torch.Size([64])                     64            128          192
fc2.weight    torch.Size([64, 64])               4096            192         4288
fc2.bias      torch.Size([64])                     64           4288         4352
fc3.weight    torch.Size([64, 64])               4096           4352         8448
fc3.bias      torch.Size([64])                     64           8448         8512
fc5.weight    torch.Size([3, 64])                 192           8512         8704
fc5.bias      torch.Size([3])                       3           8704         8707
Total number of parameters: 8707
In [ ]:
# Hypernet class

total_params=8707
ss=256

class HyperNet(nn.Module):
    def __init__(self, num_layers=5, num_neurons=256, activation=torch.sin, n_out=3):
        super().__init__()
        self.activation = activation
        self.n_out = total_params
        self.fc1 = nn.Linear(5, ss)
        self.fc2 = nn.Linear(ss, ss)
        self.fc3 = nn.Linear(ss, total_params)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
In [ ]:
hp = HyperNet().to(device)

out_hp = hp(torch.rand(10, 5).to(device))
print(out_hp.shape)

weights_flattened  = out_hp.mean(dim=0)
print(weights_flattened.shape)

print(hp(torch.rand(2000, 5).to(device)).shape)
torch.Size([10, 8707])
torch.Size([8707])
torch.Size([2000, 8707])
In [ ]:
torch.manual_seed(42)
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape

img0_hyper_data = []
datano = 1 #instead of 20 images
context_precent = 100 # increased the context size to 10 percent

for i in range(datano):
    sh_indexi = torch.randperm(img0_X.shape[0])
    cont_img0_X_shi = img0_X[sh_indexi][0:int(len(img0_X)*context_precent/100)]
    cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X)*context_precent/100)]
    context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)

    train_img0_X_shi = img0_X
    train_img0_Y_shi = img0_Y

    datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
    img0_hyper_data.append(datas)
In [ ]:
img0_hyper_data = []
img0_hyper_data.append([torch.cat((img0_X, img0_Y), dim = 1), img0_X, img0_Y])
In [ ]:
img0_hyper_data[0][0].shape, img0_hyper_data[0][1].shape, img0_hyper_data[0][2].shape
Out[ ]:
(torch.Size([4096, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
params2 = hp(img0_hyper_data[0][0]).mean(dim=0)
params2.shape, img0_hyper_data[0][0].shape
Out[ ]:
(torch.Size([8707]), torch.Size([4096, 5]))
In [ ]:
model = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)
model
Out[ ]:
NN(
  (fc1): Linear(in_features=2, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=3, bias=True)
)
In [ ]:
# Installing the astra library
! pip install astra-lib
Collecting astra-lib
  Downloading astra-lib-0.0.2.tar.gz (136 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 136.6/136.6 kB 3.3 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from astra-lib) (1.23.5)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from astra-lib) (1.5.3)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from astra-lib) (3.7.1)
Requirement already satisfied: xarray in /usr/local/lib/python3.10/dist-packages (from astra-lib) (2023.7.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from astra-lib) (4.66.1)
Collecting optree (from astra-lib)
  Downloading optree-0.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (286 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 286.8/286.8 kB 15.3 MB/s eta 0:00:00
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (4.44.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (23.2)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (3.1.1)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (2.8.2)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree->astra-lib) (4.5.0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->astra-lib) (2023.3.post1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->astra-lib) (1.16.0)
Building wheels for collected packages: astra-lib
  Building wheel for astra-lib (pyproject.toml) ... done
  Created wheel for astra-lib: filename=astra_lib-0.0.2-py3-none-any.whl size=21845 sha256=21853f35730a094084175db7468b414103a824ca437ec06c5b23383555786d99
  Stored in directory: /root/.cache/pip/wheels/db/a6/8d/73931c696ff5c17a3364e2962cf8680790ae07a3e8aa55587b
Successfully built astra-lib
Installing collected packages: optree, astra-lib
Successfully installed astra-lib-0.0.2 optree-0.10.0
In [ ]:
import astra
from astra.torch.utils import ravel_pytree
flat_weights, unravel_fn = ravel_pytree(dict(model.named_parameters()))
print(flat_weights.shape)
torch.Size([8707])
In [ ]:
unravel_fn
Out[ ]:
<function astra.torch.utils.ravel_pytree.<locals>.unravel_function(flat_params)>
In [ ]:
# Training the network to recreate the image
torch.manual_seed(42)

# targetnet = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device) # Target Net
# new_dict = targetnet.state_dict()

def train_hypernet(hypernet, target_net, lr, hyper_data, epochs, verbose=True):
    """
    net: torch.nn.Module
    lr: float
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, 3)
    """
    loss_list = []
    datano = len(hyper_data)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)

    for epoch in range(epochs):
        running_loss = 0.0
        for i in range(datano):
            context_data, train_img_X_shi, train_img_Y_shi = hyper_data[i]

            optimizer.zero_grad()
            output = hypernet(context_data)
            params = output.mean(dim=0)

            flat_weights, unravel_fn = ravel_pytree(dict(target_net.named_parameters()))
            new_dict = unravel_fn(params)

            # new_dict.update(unravel_fn(params))

            outputs = torch.func.functional_call(target_net, new_dict, train_img_X_shi)

            loss = criterion(outputs, train_img_Y_shi)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        loss_list.append(running_loss)

        if verbose and epoch % 50 == 0:
            # Values are not chaninging ? why?
            # print(hypernet.state_dict()["fc1.weight"][0:2, 0:2])
            print(f"Epoch {epoch} loss: {running_loss/datano:.6f}")

    return loss_list
In [ ]:
import time
torch.manual_seed(42)

start_time = time.time()
hp1 = HyperNet(activation = nn.ReLU() ).to(device)
targetnet1 = NN().to(device)
loss_list = train_hypernet(hp1, targetnet1, lr = 3e-4, hyper_data = img0_hyper_data, epochs = 15000, verbose=True)
# reduced learning rate
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.403650
Epoch 50 loss: 0.071302
Epoch 100 loss: 0.071194
Epoch 150 loss: 0.071189
Epoch 200 loss: 0.071183
Epoch 250 loss: 0.071176
Epoch 300 loss: 0.071166
Epoch 350 loss: 0.071154
Epoch 400 loss: 0.071146
Epoch 450 loss: 0.071144
Epoch 500 loss: 0.071142
Epoch 550 loss: 0.071140
Epoch 600 loss: 0.071138
Epoch 650 loss: 0.071136
Epoch 700 loss: 0.071135
Epoch 750 loss: 0.071133
Epoch 800 loss: 0.071132
Epoch 850 loss: 0.071130
Epoch 900 loss: 0.071129
Epoch 950 loss: 0.071127
Epoch 1000 loss: 0.071126
Epoch 1050 loss: 0.071128
Epoch 1100 loss: 0.071125
Epoch 1150 loss: 0.071124
Epoch 1200 loss: 0.071129
Epoch 1250 loss: 0.071126
Epoch 1300 loss: 0.071125
Epoch 1350 loss: 0.071124
Epoch 1400 loss: 0.071123
Epoch 1450 loss: 0.071122
Epoch 1500 loss: 0.071511
Epoch 1550 loss: 0.071131
Epoch 1600 loss: 0.071128
Epoch 1650 loss: 0.071126
Epoch 1700 loss: 0.071125
Epoch 1750 loss: 0.071124
Epoch 1800 loss: 0.071123
Epoch 1850 loss: 0.071122
Epoch 1900 loss: 0.071121
Epoch 1950 loss: 0.071120
Epoch 2000 loss: 0.071119
Epoch 2050 loss: 0.071293
Epoch 2100 loss: 0.071133
Epoch 2150 loss: 0.071130
Epoch 2200 loss: 0.071128
Epoch 2250 loss: 0.071126
Epoch 2300 loss: 0.071125
Epoch 2350 loss: 0.071124
Epoch 2400 loss: 0.071123
Epoch 2450 loss: 0.071122
Epoch 2500 loss: 0.071121
Epoch 2550 loss: 0.071120
Epoch 2600 loss: 0.071119
Epoch 2650 loss: 0.071118
Epoch 2700 loss: 0.071116
Epoch 2750 loss: 0.071143
Epoch 2800 loss: 0.071109
Epoch 2850 loss: 0.070792
Epoch 2900 loss: 0.072253
Epoch 2950 loss: 0.070422
Epoch 3000 loss: 0.070312
Epoch 3050 loss: 0.072265
Epoch 3100 loss: 0.070852
Epoch 3150 loss: 0.070409
Epoch 3200 loss: 0.070321
Epoch 3250 loss: 0.070825
Epoch 3300 loss: 0.070360
Epoch 3350 loss: 0.070298
Epoch 3400 loss: 0.070170
Epoch 3450 loss: 0.069760
Epoch 3500 loss: 0.069480
Epoch 3550 loss: 0.069419
Epoch 3600 loss: 0.069325
Epoch 3650 loss: 0.069534
Epoch 3700 loss: 0.068746
Epoch 3750 loss: 0.068677
Epoch 3800 loss: 0.068613
Epoch 3850 loss: 0.068559
Epoch 3900 loss: 0.069640
Epoch 3950 loss: 0.068416
Epoch 4000 loss: 0.068115
Epoch 4050 loss: 0.067895
Epoch 4100 loss: 0.067805
Epoch 4150 loss: 0.066609
Epoch 4200 loss: 0.057200
Epoch 4250 loss: 0.053656
Epoch 4300 loss: 0.051547
Epoch 4350 loss: 0.049484
Epoch 4400 loss: 0.047249
Epoch 4450 loss: 0.039190
Epoch 4500 loss: 0.035528
Epoch 4550 loss: 0.035988
Epoch 4600 loss: 0.031441
Epoch 4650 loss: 0.030642
Epoch 4700 loss: 0.029882
Epoch 4750 loss: 0.028969
Epoch 4800 loss: 0.028517
Epoch 4850 loss: 0.027376
Epoch 4900 loss: 0.028707
Epoch 4950 loss: 0.024760
Epoch 5000 loss: 0.024197
Epoch 5050 loss: 0.023947
Epoch 5100 loss: 0.024595
Epoch 5150 loss: 0.022410
Epoch 5200 loss: 0.022466
Epoch 5250 loss: 0.021812
Epoch 5300 loss: 0.020312
Epoch 5350 loss: 0.021673
Epoch 5400 loss: 0.019209
Epoch 5450 loss: 0.018692
Epoch 5500 loss: 0.018397
Epoch 5550 loss: 0.018506
Epoch 5600 loss: 0.017837
Epoch 5650 loss: 0.017525
Epoch 5700 loss: 0.017342
Epoch 5750 loss: 0.017501
Epoch 5800 loss: 0.017087
Epoch 5850 loss: 0.016972
Epoch 5900 loss: 0.016583
Epoch 5950 loss: 0.016422
Epoch 6000 loss: 0.016368
Epoch 6050 loss: 0.016281
Epoch 6100 loss: 0.016753
Epoch 6150 loss: 0.016152
Epoch 6200 loss: 0.016031
Epoch 6250 loss: 0.017693
Epoch 6300 loss: 0.016301
Epoch 6350 loss: 0.016454
Epoch 6400 loss: 0.017366
Epoch 6450 loss: 0.015956
Epoch 6500 loss: 0.015282
Epoch 6550 loss: 0.015224
Epoch 6600 loss: 0.015460
Epoch 6650 loss: 0.015093
Epoch 6700 loss: 0.015143
Epoch 6750 loss: 0.015584
Epoch 6800 loss: 0.014723
Epoch 6850 loss: 0.014729
Epoch 6900 loss: 0.014920
Epoch 6950 loss: 0.014756
Epoch 7000 loss: 0.014658
Epoch 7050 loss: 0.014436
Epoch 7100 loss: 0.015304
Epoch 7150 loss: 0.014708
Epoch 7200 loss: 0.014184
Epoch 7250 loss: 0.014117
Epoch 7300 loss: 0.014174
Epoch 7350 loss: 0.013970
Epoch 7400 loss: 0.014577
Epoch 7450 loss: 0.014161
Epoch 7500 loss: 0.013952
Epoch 7550 loss: 0.013838
Epoch 7600 loss: 0.013753
Epoch 7650 loss: 0.013718
Epoch 7700 loss: 0.014550
Epoch 7750 loss: 0.014340
Epoch 7800 loss: 0.013824
Epoch 7850 loss: 0.013968
Epoch 7900 loss: 0.015152
Epoch 7950 loss: 0.013289
Epoch 8000 loss: 0.013215
Epoch 8050 loss: 0.013265
Epoch 8100 loss: 0.013107
Epoch 8150 loss: 0.013669
Epoch 8200 loss: 0.013174
Epoch 8250 loss: 0.012806
Epoch 8300 loss: 0.012764
Epoch 8350 loss: 0.012953
Epoch 8400 loss: 0.013129
Epoch 8450 loss: 0.012777
Epoch 8500 loss: 0.012705
Epoch 8550 loss: 0.014406
Epoch 8600 loss: 0.013657
Epoch 8650 loss: 0.012415
Epoch 8700 loss: 0.012482
Epoch 8750 loss: 0.012352
Epoch 8800 loss: 0.012332
Epoch 8850 loss: 0.013093
Epoch 8900 loss: 0.012409
Epoch 8950 loss: 0.012751
Epoch 9000 loss: 0.012107
Epoch 9050 loss: 0.011987
Epoch 9100 loss: 0.011990
Epoch 9150 loss: 0.011914
Epoch 9200 loss: 0.011947
Epoch 9250 loss: 0.011958
Epoch 9300 loss: 0.011884
Epoch 9350 loss: 0.011888
Epoch 9400 loss: 0.011983
Epoch 9450 loss: 0.011705
Epoch 9500 loss: 0.011694
Epoch 9550 loss: 0.012067
Epoch 9600 loss: 0.011904
Epoch 9650 loss: 0.011582
Epoch 9700 loss: 0.012040
Epoch 9750 loss: 0.011692
Epoch 9800 loss: 0.011438
Epoch 9850 loss: 0.011446
Epoch 9900 loss: 0.011637
Epoch 9950 loss: 0.011481
Epoch 10000 loss: 0.011284
Epoch 10050 loss: 0.012001
Epoch 10100 loss: 0.011459
Epoch 10150 loss: 0.011424
Epoch 10200 loss: 0.011292
Epoch 10250 loss: 0.011367
Epoch 10300 loss: 0.011308
Epoch 10350 loss: 0.011074
Epoch 10400 loss: 0.011024
Epoch 10450 loss: 0.011534
Epoch 10500 loss: 0.011289
Epoch 10550 loss: 0.011159
Epoch 10600 loss: 0.011577
Epoch 10650 loss: 0.011358
Epoch 10700 loss: 0.010867
Epoch 10750 loss: 0.011331
Epoch 10800 loss: 0.010989
Epoch 10850 loss: 0.010952
Epoch 10900 loss: 0.010838
Epoch 10950 loss: 0.010806
Epoch 11000 loss: 0.011251
Epoch 11050 loss: 0.010741
Epoch 11100 loss: 0.010635
Epoch 11150 loss: 0.010690
Epoch 11200 loss: 0.010726
Epoch 11250 loss: 0.011051
Epoch 11300 loss: 0.011013
Epoch 11350 loss: 0.011148
Epoch 11400 loss: 0.010769
Epoch 11450 loss: 0.010810
Epoch 11500 loss: 0.010482
Epoch 11550 loss: 0.010950
Epoch 11600 loss: 0.010358
Epoch 11650 loss: 0.010617
Epoch 11700 loss: 0.010756
Epoch 11750 loss: 0.010459
Epoch 11800 loss: 0.010271
Epoch 11850 loss: 0.010931
Epoch 11900 loss: 0.010484
Epoch 11950 loss: 0.010253
Epoch 12000 loss: 0.010425
Epoch 12050 loss: 0.010247
Epoch 12100 loss: 0.010368
Epoch 12150 loss: 0.010326
Epoch 12200 loss: 0.010047
Epoch 12250 loss: 0.010260
Epoch 12300 loss: 0.010120
Epoch 12350 loss: 0.010563
Epoch 12400 loss: 0.010085
Epoch 12450 loss: 0.009982
Epoch 12500 loss: 0.010174
Epoch 12550 loss: 0.010021
Epoch 12600 loss: 0.009996
Epoch 12650 loss: 0.010109
Epoch 12700 loss: 0.009951
Epoch 12750 loss: 0.009865
Epoch 12800 loss: 0.010103
Epoch 12850 loss: 0.009926
Epoch 12900 loss: 0.010144
Epoch 12950 loss: 0.009936
Epoch 13000 loss: 0.010185
Epoch 13050 loss: 0.010106
Epoch 13100 loss: 0.010005
Epoch 13150 loss: 0.009894
Epoch 13200 loss: 0.009666
Epoch 13250 loss: 0.009817
Epoch 13300 loss: 0.009859
Epoch 13350 loss: 0.009615
Epoch 13400 loss: 0.009594
Epoch 13450 loss: 0.009915
Epoch 13500 loss: 0.009762
Epoch 13550 loss: 0.009700
Epoch 13600 loss: 0.009972
Epoch 13650 loss: 0.010076
Epoch 13700 loss: 0.009769
Epoch 13750 loss: 0.009585
Epoch 13800 loss: 0.009591
Epoch 13850 loss: 0.009793
Epoch 13900 loss: 0.009485
Epoch 13950 loss: 0.009821
Epoch 14000 loss: 0.009819
Epoch 14050 loss: 0.009761
Epoch 14100 loss: 0.009630
Epoch 14150 loss: 0.009575
Epoch 14200 loss: 0.009442
Epoch 14250 loss: 0.010140
Epoch 14300 loss: 0.009336
Epoch 14350 loss: 0.009232
Epoch 14400 loss: 0.010051
Epoch 14450 loss: 0.009685
Epoch 14500 loss: 0.009288
Epoch 14550 loss: 0.009366
Epoch 14600 loss: 0.009287
Epoch 14650 loss: 0.009389
Epoch 14700 loss: 0.009396
Epoch 14750 loss: 0.009255
Epoch 14800 loss: 0.009340
Epoch 14850 loss: 0.009224
Epoch 14900 loss: 0.009066
Epoch 14950 loss: 0.009403
Training time: 286.73 seconds
In [ ]:
loss_list2 = train_hypernet(hp1, targetnet1, lr = 3e-4, hyper_data = img0_hyper_data, epochs = 10000, verbose=True)
Epoch 0 loss: 0.009098
Epoch 50 loss: 0.013128
Epoch 100 loss: 0.009980
Epoch 150 loss: 0.009274
Epoch 200 loss: 0.009523
Epoch 250 loss: 0.009007
Epoch 300 loss: 0.008988
Epoch 350 loss: 0.009015
Epoch 400 loss: 0.009092
Epoch 450 loss: 0.009587
Epoch 500 loss: 0.008949
Epoch 550 loss: 0.009135
Epoch 600 loss: 0.009011
Epoch 650 loss: 0.008903
Epoch 700 loss: 0.008935
Epoch 750 loss: 0.009252
Epoch 800 loss: 0.009137
Epoch 850 loss: 0.008866
Epoch 900 loss: 0.009325
Epoch 950 loss: 0.009051
Epoch 1000 loss: 0.008735
Epoch 1050 loss: 0.008841
Epoch 1100 loss: 0.009356
Epoch 1150 loss: 0.008910
Epoch 1200 loss: 0.008656
Epoch 1250 loss: 0.008877
Epoch 1300 loss: 0.008681
Epoch 1350 loss: 0.008711
Epoch 1400 loss: 0.009730
Epoch 1450 loss: 0.009411
Epoch 1500 loss: 0.008878
Epoch 1550 loss: 0.008693
Epoch 1600 loss: 0.009151
Epoch 1650 loss: 0.008871
Epoch 1700 loss: 0.008522
Epoch 1750 loss: 0.008817
Epoch 1800 loss: 0.008764
Epoch 1850 loss: 0.008510
Epoch 1900 loss: 0.008537
Epoch 1950 loss: 0.008460
Epoch 2000 loss: 0.008674
Epoch 2050 loss: 0.009251
Epoch 2100 loss: 0.008363
Epoch 2150 loss: 0.008522
Epoch 2200 loss: 0.008513
Epoch 2250 loss: 0.008552
Epoch 2300 loss: 0.008843
Epoch 2350 loss: 0.008299
Epoch 2400 loss: 0.008840
Epoch 2450 loss: 0.008896
Epoch 2500 loss: 0.008771
Epoch 2550 loss: 0.008288
Epoch 2600 loss: 0.008335
Epoch 2650 loss: 0.008693
Epoch 2700 loss: 0.008339
Epoch 2750 loss: 0.008195
Epoch 2800 loss: 0.008204
Epoch 2850 loss: 0.008378
Epoch 2900 loss: 0.008293
Epoch 2950 loss: 0.008314
Epoch 3000 loss: 0.008148
Epoch 3050 loss: 0.008174
Epoch 3100 loss: 0.008599
Epoch 3150 loss: 0.008363
Epoch 3200 loss: 0.008353
Epoch 3250 loss: 0.008135
Epoch 3300 loss: 0.008423
Epoch 3350 loss: 0.008360
Epoch 3400 loss: 0.008199
Epoch 3450 loss: 0.008243
Epoch 3500 loss: 0.008171
Epoch 3550 loss: 0.008225
Epoch 3600 loss: 0.008203
Epoch 3650 loss: 0.008148
Epoch 3700 loss: 0.008174
Epoch 3750 loss: 0.008039
Epoch 3800 loss: 0.008076
Epoch 3850 loss: 0.008205
Epoch 3900 loss: 0.008367
Epoch 3950 loss: 0.008299
Epoch 4000 loss: 0.008120
Epoch 4050 loss: 0.008021
Epoch 4100 loss: 0.008654
Epoch 4150 loss: 0.008405
Epoch 4200 loss: 0.008289
Epoch 4250 loss: 0.008199
Epoch 4300 loss: 0.008161
Epoch 4350 loss: 0.008186
Epoch 4400 loss: 0.007934
Epoch 4450 loss: 0.008202
Epoch 4500 loss: 0.008064
Epoch 4550 loss: 0.007923
Epoch 4600 loss: 0.007967
Epoch 4650 loss: 0.008116
Epoch 4700 loss: 0.007972
Epoch 4750 loss: 0.008034
Epoch 4800 loss: 0.008200
Epoch 4850 loss: 0.008282
Epoch 4900 loss: 0.007968
Epoch 4950 loss: 0.007777
Epoch 5000 loss: 0.007802
Epoch 5050 loss: 0.007816
Epoch 5100 loss: 0.008060
Epoch 5150 loss: 0.008199
Epoch 5200 loss: 0.007822
Epoch 5250 loss: 0.007950
Epoch 5300 loss: 0.008030
Epoch 5350 loss: 0.007862
Epoch 5400 loss: 0.007946
Epoch 5450 loss: 0.007973
Epoch 5500 loss: 0.007721
Epoch 5550 loss: 0.007928
Epoch 5600 loss: 0.007797
Epoch 5650 loss: 0.007869
Epoch 5700 loss: 0.008315
Epoch 5750 loss: 0.007892
Epoch 5800 loss: 0.007714
Epoch 5850 loss: 0.008180
Epoch 5900 loss: 0.007852
Epoch 5950 loss: 0.007752
Epoch 6000 loss: 0.008154
Epoch 6050 loss: 0.007822
Epoch 6100 loss: 0.007803
Epoch 6150 loss: 0.007877
Epoch 6200 loss: 0.007957
Epoch 6250 loss: 0.008377
Epoch 6300 loss: 0.007943
Epoch 6350 loss: 0.007683
Epoch 6400 loss: 0.007771
Epoch 6450 loss: 0.007806
Epoch 6500 loss: 0.007923
Epoch 6550 loss: 0.007985
Epoch 6600 loss: 0.007675
Epoch 6650 loss: 0.007631
Epoch 6700 loss: 0.007656
Epoch 6750 loss: 0.008071
Epoch 6800 loss: 0.007907
Epoch 6850 loss: 0.007681
Epoch 6900 loss: 0.008560
Epoch 6950 loss: 0.007597
Epoch 7000 loss: 0.007686
Epoch 7050 loss: 0.007675
Epoch 7100 loss: 0.007700
Epoch 7150 loss: 0.007716
Epoch 7200 loss: 0.008152
Epoch 7250 loss: 0.007822
Epoch 7300 loss: 0.007648
Epoch 7350 loss: 0.007931
Epoch 7400 loss: 0.007609
Epoch 7450 loss: 0.007668
Epoch 7500 loss: 0.007560
Epoch 7550 loss: 0.007940
Epoch 7600 loss: 0.007523
Epoch 7650 loss: 0.007659
Epoch 7700 loss: 0.007517
Epoch 7750 loss: 0.007579
Epoch 7800 loss: 0.008199
Epoch 7850 loss: 0.007650
Epoch 7900 loss: 0.007635
Epoch 7950 loss: 0.007524
Epoch 8000 loss: 0.007621
Epoch 8050 loss: 0.007606
Epoch 8100 loss: 0.007712
Epoch 8150 loss: 0.007747
Epoch 8200 loss: 0.007711
Epoch 8250 loss: 0.007574
Epoch 8300 loss: 0.009379
Epoch 8350 loss: 0.007912
Epoch 8400 loss: 0.007517
Epoch 8450 loss: 0.007770
Epoch 8500 loss: 0.007442
Epoch 8550 loss: 0.007595
Epoch 8600 loss: 0.007417
Epoch 8650 loss: 0.007445
Epoch 8700 loss: 0.007657
Epoch 8750 loss: 0.007604
Epoch 8800 loss: 0.007451
Epoch 8850 loss: 0.007435
Epoch 8900 loss: 0.007390
Epoch 8950 loss: 0.008437
Epoch 9000 loss: 0.007788
Epoch 9050 loss: 0.007606
Epoch 9100 loss: 0.007656
Epoch 9150 loss: 0.007530
Epoch 9200 loss: 0.007515
Epoch 9250 loss: 0.007492
Epoch 9300 loss: 0.007468
Epoch 9350 loss: 0.007586
Epoch 9400 loss: 0.007702
Epoch 9450 loss: 0.007509
Epoch 9500 loss: 0.007505
Epoch 9550 loss: 0.007531
Epoch 9600 loss: 0.007483
Epoch 9650 loss: 0.007448
Epoch 9700 loss: 0.007610
Epoch 9750 loss: 0.007565
Epoch 9800 loss: 0.007536
Epoch 9850 loss: 0.007473
Epoch 9900 loss: 0.007490
Epoch 9950 loss: 0.008649
In [ ]:
loss_list.extend(loss_list2)
plt.plot(loss_list[:25000])
plt.xlabel("Epochs")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
torch.save(hp1.state_dict(), "img0_hypernet.pt")
In [ ]:
def plot_reconstructed_and_original_image(original_img, hypernet, targetnet, X, context, title=""):
    """
    original_img = Original image
    Hypernet = hypernet
    targetnet = targetnet
    X = the full scaled image X
    context = (n,5) shaped context
    """
    # num_channels, height, width = original_img.shape

    num_channels, height, width = original_img.shape

    with torch.no_grad():
        params = hypernet(context).mean(dim=0)

        flat_weights,unravel_fn= ravel_pytree(dict(targetnet.named_parameters()))

        parameter_dictionary = unravel_fn(params)

        outputs = torch.func.functional_call(targetnet, parameter_dictionary, X)
        print(output.shape)
        outputs = outputs.reshape(num_channels, height, width)
        outputs = outputs.permute(1, 2, 0)

    fig = plt.figure(figsize=(8, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])

    ax0.imshow(output.detach().cpu().numpy())
    ax0.set_title("Reconstructed Image")


    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")

    for a in [ax0, ax1]:
        a.axis("off")


    fig.suptitle(title, y=0.9)
    plt.tight_layout()
In [ ]:
context = torch.cat((img0_X, img0_Y), dim = 1)
plot_reconstructed_and_original_image(img0, hp1, targetnet1, img0_X, context, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])

Training 5 Batches of 1 image¶

In [ ]:
torch.manual_seed(42)
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape

img0_hyper_data2 = []
datano = 5 #instead of 20 images
context_precent_start = 10 # increased the context size to 10 percent
context_percent_end = 80

for i in range(datano):
    contextp = context_precent_start + (i/datano)*(context_percent_end-context_precent_start)
    sh_indexi = torch.randperm(img0_X.shape[0])
    cont_img0_X_shi = img0_X[sh_indexi][0:int(len(img0_X)*contextp/100)]
    cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X)*contextp/100)]
    context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)

    train_img0_X_shi = img0_X
    train_img0_Y_shi = img0_Y

    datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
    img0_hyper_data2.append(datas)
In [ ]:
img0_hyper_data2[1][0].shape, img0_hyper_data2[1][1].shape, img0_hyper_data2[1][1].shape
Out[ ]:
(torch.Size([983, 5]), torch.Size([4096, 2]), torch.Size([4096, 2]))
In [ ]:
import time
torch.manual_seed(42)

start_time = time.time()
hp_img2 = HyperNet(activation = nn.ReLU() ).to(device)
targetnet_img2 = NN().to(device)
loss_list_img2 = train_hypernet(hp_img2, targetnet_img2, lr = 3e-4, hyper_data = img0_hyper_data2, epochs = 10000, verbose=True)
# reduced learning rate
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.192977
Epoch 50 loss: 0.071178
Epoch 100 loss: 0.071160
Epoch 150 loss: 0.071141
Epoch 200 loss: 0.071152
Epoch 250 loss: 0.071129
Epoch 300 loss: 0.071131
Epoch 350 loss: 0.071149
Epoch 400 loss: 0.071139
Epoch 450 loss: 0.071147
Epoch 500 loss: 0.071271
Epoch 550 loss: 0.071155
Epoch 600 loss: 0.069042
Epoch 650 loss: 0.068657
Epoch 700 loss: 0.055331
Epoch 750 loss: 0.036802
Epoch 800 loss: 0.034901
Epoch 850 loss: 0.028884
Epoch 900 loss: 0.026838
Epoch 950 loss: 0.025977
Epoch 1000 loss: 0.026541
Epoch 1050 loss: 0.022476
Epoch 1100 loss: 0.021346
Epoch 1150 loss: 0.019726
Epoch 1200 loss: 0.021250
Epoch 1250 loss: 0.017727
Epoch 1300 loss: 0.017597
Epoch 1350 loss: 0.017356
Epoch 1400 loss: 0.017615
Epoch 1450 loss: 0.017500
Epoch 1500 loss: 0.015576
Epoch 1550 loss: 0.015585
Epoch 1600 loss: 0.014925
Epoch 1650 loss: 0.014785
Epoch 1700 loss: 0.014516
Epoch 1750 loss: 0.014382
Epoch 1800 loss: 0.014247
Epoch 1850 loss: 0.013916
Epoch 1900 loss: 0.013573
Epoch 1950 loss: 0.013585
Epoch 2000 loss: 0.013410
Epoch 2050 loss: 0.012568
Epoch 2100 loss: 0.012534
Epoch 2150 loss: 0.012586
Epoch 2200 loss: 0.012624
Epoch 2250 loss: 0.012460
Epoch 2300 loss: 0.012309
Epoch 2350 loss: 0.011867
Epoch 2400 loss: 0.012249
Epoch 2450 loss: 0.012101
Epoch 2500 loss: 0.011160
Epoch 2550 loss: 0.011214
Epoch 2600 loss: 0.011539
Epoch 2650 loss: 0.011330
Epoch 2700 loss: 0.011088
Epoch 2750 loss: 0.010844
Epoch 2800 loss: 0.010697
Epoch 2850 loss: 0.010616
Epoch 2900 loss: 0.010692
Epoch 2950 loss: 0.010700
Epoch 3000 loss: 0.010689
Epoch 3050 loss: 0.010260
Epoch 3100 loss: 0.010396
Epoch 3150 loss: 0.009796
Epoch 3200 loss: 0.009536
Epoch 3250 loss: 0.010086
Epoch 3300 loss: 0.009856
Epoch 3350 loss: 0.009641
Epoch 3400 loss: 0.009877
Epoch 3450 loss: 0.009714
Epoch 3500 loss: 0.009413
Epoch 3550 loss: 0.009772
Epoch 3600 loss: 0.009389
Epoch 3650 loss: 0.009169
Epoch 3700 loss: 0.009201
Epoch 3750 loss: 0.009153
Epoch 3800 loss: 0.009148
Epoch 3850 loss: 0.008968
Epoch 3900 loss: 0.009011
Epoch 3950 loss: 0.009109
Epoch 4000 loss: 0.008994
Epoch 4050 loss: 0.008814
Epoch 4100 loss: 0.009027
Epoch 4150 loss: 0.008648
Epoch 4200 loss: 0.009128
Epoch 4250 loss: 0.008316
Epoch 4300 loss: 0.008570
Epoch 4350 loss: 0.010499
Epoch 4400 loss: 0.008047
Epoch 4450 loss: 0.008221
Epoch 4500 loss: 0.008445
Epoch 4550 loss: 0.008254
Epoch 4600 loss: 0.008390
Epoch 4650 loss: 0.008111
Epoch 4700 loss: 0.009442
Epoch 4750 loss: 0.008505
Epoch 4800 loss: 0.008337
Epoch 4850 loss: 0.008782
Epoch 4900 loss: 0.008167
Epoch 4950 loss: 0.008045
Epoch 5000 loss: 0.008011
Epoch 5050 loss: 0.008032
Epoch 5100 loss: 0.008029
Epoch 5150 loss: 0.007824
Epoch 5200 loss: 0.007987
Epoch 5250 loss: 0.008000
Epoch 5300 loss: 0.008100
Epoch 5350 loss: 0.007967
Epoch 5400 loss: 0.007496
Epoch 5450 loss: 0.007570
Epoch 5500 loss: 0.007435
Epoch 5550 loss: 0.007867
Epoch 5600 loss: 0.007734
Epoch 5650 loss: 0.008091
Epoch 5700 loss: 0.007538
Epoch 5750 loss: 0.007734
Epoch 5800 loss: 0.007423
Epoch 5850 loss: 0.007576
Epoch 5900 loss: 0.007737
Epoch 5950 loss: 0.007457
Epoch 6000 loss: 0.007553
Epoch 6050 loss: 0.007804
Epoch 6100 loss: 0.007460
Epoch 6150 loss: 0.007287
Epoch 6200 loss: 0.007043
Epoch 6250 loss: 0.007309
Epoch 6300 loss: 0.007235
Epoch 6350 loss: 0.007657
Epoch 6400 loss: 0.007053
Epoch 6450 loss: 0.007620
Epoch 6500 loss: 0.006970
Epoch 6550 loss: 0.007006
Epoch 6600 loss: 0.007215
Epoch 6650 loss: 0.007075
Epoch 6700 loss: 0.007141
Epoch 6750 loss: 0.006854
Epoch 6800 loss: 0.007412
Epoch 6850 loss: 0.007005
Epoch 6900 loss: 0.007194
Epoch 6950 loss: 0.006748
Epoch 7000 loss: 0.006982
Epoch 7050 loss: 0.006718
Epoch 7100 loss: 0.006766
Epoch 7150 loss: 0.006730
Epoch 7200 loss: 0.006824
Epoch 7250 loss: 0.007010
Epoch 7300 loss: 0.006894
Epoch 7350 loss: 0.006803
Epoch 7400 loss: 0.006902
Epoch 7450 loss: 0.006521
Epoch 7500 loss: 0.006476
Epoch 7550 loss: 0.006627
Epoch 7600 loss: 0.006803
Epoch 7650 loss: 0.007007
Epoch 7700 loss: 0.006755
Epoch 7750 loss: 0.006311
Epoch 7800 loss: 0.006439
Epoch 7850 loss: 0.006668
Epoch 7900 loss: 0.007098
Epoch 7950 loss: 0.007805
Epoch 8000 loss: 0.006405
Epoch 8050 loss: 0.006763
Epoch 8100 loss: 0.006640
Epoch 8150 loss: 0.006424
Epoch 8200 loss: 0.006426
Epoch 8250 loss: 0.006576
Epoch 8300 loss: 0.006390
Epoch 8350 loss: 0.006189
Epoch 8400 loss: 0.006715
Epoch 8450 loss: 0.006440
Epoch 8500 loss: 0.006311
Epoch 8550 loss: 0.006487
Epoch 8600 loss: 0.006274
Epoch 8650 loss: 0.006384
Epoch 8700 loss: 0.006170
Epoch 8750 loss: 0.006158
Epoch 8800 loss: 0.006276
Epoch 8850 loss: 0.006181
Epoch 8900 loss: 0.006423
Epoch 8950 loss: 0.006242
Epoch 9000 loss: 0.005993
Epoch 9050 loss: 0.006317
Epoch 9100 loss: 0.005988
Epoch 9150 loss: 0.006118
Epoch 9200 loss: 0.006159
Epoch 9250 loss: 0.006061
Epoch 9300 loss: 0.006200
Epoch 9350 loss: 0.006279
Epoch 9400 loss: 0.006120
Epoch 9450 loss: 0.006521
Epoch 9500 loss: 0.006148
Epoch 9550 loss: 0.006173
Epoch 9600 loss: 0.005929
Epoch 9650 loss: 0.006773
Epoch 9700 loss: 0.005978
Epoch 9750 loss: 0.005842
Epoch 9800 loss: 0.005882
Epoch 9850 loss: 0.006384
Epoch 9900 loss: 0.006089
Epoch 9950 loss: 0.006302
Training time: 396.95 seconds
In [ ]:
test_contextp = 1
sh_indexi = torch.randperm(img0_X.shape[0])
test_cont_img0_X = img0_X[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_cont_img0_Y = img0_Y[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_context_data = torch.cat((test_cont_img0_X, test_cont_img0_Y), dim = 1)
test_context_data.shape
Out[ ]:
torch.Size([40, 5])
In [ ]:
plot_reconstructed_and_original_image(img0, hp_img2, targetnet_img2, img0_X, test_context_data, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])

Images of 50 celebrities, 50 images each¶

In [ ]:
torch.manual_seed(42)
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape

no_of_images = 50

images_hyper_data = []
datano = 4 #instead of 20 images
context_precent_start = 10 # increased the context size to 10 percent
context_percent_end = 50

for celebno in range(no_of_images):
  image_X, image_Y, scaler_X = create_scaled_cmap(celeba_dataset[celebno], rt = False)
  for i in range(datano):
      contextp = context_precent_start + (i/datano)*(context_percent_end-context_precent_start)
      sh_indexi = torch.randperm(image_X.shape[0])
      cont_img0_X_shi = image_X[sh_indexi][0:int(len(img0_X)*contextp/100)]
      cont_img0_Y_shi = image_Y[sh_indexi][0:int(len(img0_X)*contextp/100)]
      context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)

      train_img0_X_shi = image_X
      train_img0_Y_shi = image_Y

      datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
      images_hyper_data.append(datas)
In [ ]:
len(images_hyper_data), images_hyper_data[0][0].shape, images_hyper_data[0][1].shape, images_hyper_data[0][2].shape
Out[ ]:
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
import time
torch.manual_seed(42)

start_time = time.time()
celeb_hp = HyperNet(activation = nn.ReLU() ).to(device)
celeb_targetnet = NN().to(device)
celeb_loss_list = train_hypernet(celeb_hp, celeb_targetnet, lr = 3e-4, hyper_data = images_hyper_data, epochs = 4500, verbose=True)
# reduced learning rate
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.072027
Epoch 50 loss: 0.060346
Epoch 100 loss: 0.060094
Epoch 150 loss: 0.059345
Epoch 200 loss: 0.059046
Epoch 250 loss: 0.057511
Epoch 300 loss: 0.057402
Epoch 350 loss: 0.056238
Epoch 400 loss: 0.055600
Epoch 450 loss: 0.055547
Epoch 500 loss: 0.055558
Epoch 550 loss: 0.054876
Epoch 600 loss: 0.054605
Epoch 650 loss: 0.053750
Epoch 700 loss: 0.052658
Epoch 750 loss: 0.052339
Epoch 800 loss: 0.051500
Epoch 850 loss: 0.051287
Epoch 900 loss: 0.050553
Epoch 950 loss: 0.050224
Epoch 1000 loss: 0.049445
Epoch 1050 loss: 0.049254
Epoch 1100 loss: 0.048603
Epoch 1150 loss: 0.048420
Epoch 1200 loss: 0.047918
Epoch 1250 loss: 0.047415
Epoch 1300 loss: 0.046563
Epoch 1350 loss: 0.045343
Epoch 1400 loss: 0.044586
Epoch 1450 loss: 0.043709
Epoch 1500 loss: 0.040924
Epoch 1550 loss: 0.040019
Epoch 1600 loss: 0.038758
Epoch 1650 loss: 0.037633
Epoch 1700 loss: 0.037167
Epoch 1750 loss: 0.036234
Epoch 1800 loss: 0.035238
Epoch 1850 loss: 0.034266
Epoch 1900 loss: 0.033301
Epoch 1950 loss: 0.031697
Epoch 2000 loss: 0.031217
Epoch 2050 loss: 0.030073
Epoch 2100 loss: 0.029144
Epoch 2150 loss: 0.028427
Epoch 2200 loss: 0.027817
Epoch 2250 loss: 0.027696
Epoch 2300 loss: 0.026979
Epoch 2350 loss: 0.026631
Epoch 2400 loss: 0.026351
Epoch 2450 loss: 0.025369
Epoch 2500 loss: 0.025343
Epoch 2550 loss: 0.025172
Epoch 2600 loss: 0.024360
Epoch 2650 loss: 0.024442
Epoch 2700 loss: 0.024184
Epoch 2750 loss: 0.024068
Epoch 2800 loss: 0.023908
Epoch 2850 loss: 0.023486
Epoch 2900 loss: 0.023201
Epoch 2950 loss: 0.023358
Epoch 3000 loss: 0.022979
Epoch 3050 loss: 0.023087
Epoch 3100 loss: 0.023546
Epoch 3150 loss: 0.022470
Epoch 3200 loss: 0.022867
Epoch 3250 loss: 0.022026
Epoch 3300 loss: 0.022216
Epoch 3350 loss: 0.022131
Epoch 3400 loss: 0.021967
Epoch 3450 loss: 0.021718
Epoch 3500 loss: 0.021446
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-50-f1ef09ed7c66> in <cell line: 7>()
      5 celeb_hp = HyperNet(activation = nn.ReLU() ).to(device)
      6 celeb_targetnet = NN().to(device)
----> 7 celeb_loss_list = train_hypernet(celeb_hp, celeb_targetnet, lr = 3e-4, hyper_data = images_hyper_data, epochs = 4500, verbose=True)
      8 # reduced learning rate
      9 end_time = time.time()

<ipython-input-35-3255bd0398d6> in train_hypernet(hypernet, target_net, lr, hyper_data, epochs, verbose)
     34 
     35             loss = criterion(outputs, train_img_Y_shi)
---> 36             loss.backward()
     37             optimizer.step()
     38             running_loss += loss.item()

/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    490                 inputs=inputs,
    491             )
--> 492         torch.autograd.backward(
    493             self, gradient, retain_graph, create_graph, inputs=inputs
    494         )

KeyboardInterrupt: 
In [ ]:
torch.save(celeb_hp.state_dict(), "celeb_hp_hypernet.pt")
In [ ]:
torch.save(celeb_hp.state_dict(), "celeb_hp_hypernet.pt")
In [ ]:
plt.plot(celeb_loss_list)
plt.xlabel("Epochs")
plt.ylabel("loss")
In [ ]:
test_contextp = 50
celebno = 2
image_X, image_Y, scaler_X = create_scaled_cmap(celeba_dataset[celebno], rt = False)

sh_indexi = torch.randperm(img0_X.shape[0])
test_cont_img0_X = image_X[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_cont_img0_Y = image_Y[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_context_data = torch.cat((test_cont_img0_X, test_cont_img0_Y), dim = 1)
test_context_data.shape
Out[ ]:
torch.Size([2048, 5])
In [ ]:
params = celeb_hp(test_context_data).mean(dim=0)
params.shape
Out[ ]:
torch.Size([8707])
In [ ]:
flat_weights,unravel_fn= ravel_pytree(dict(celeb_targetnet.named_parameters()))
parameter_dictionary = unravel_fn(params)

outputs = torch.func.functional_call(celeb_targetnet, parameter_dictionary, img0_X)
print(output.shape)
torch.Size([64, 64, 3])
In [ ]:
plt.imshow(output.detach().cpu())
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[ ]:
<matplotlib.image.AxesImage at 0x7cd7bb6545b0>
In [ ]:
output.permute(1, 2, 0).shape
Out[ ]:
torch.Size([64, 3, 64])
In [ ]:
img0_X
Out[ ]:
tensor([[-1.0000, -1.0000],
        [-0.9683, -1.0000],
        [-0.9365, -1.0000],
        ...,
        [ 0.9365,  1.0000],
        [ 0.9683,  1.0000],
        [ 1.0000,  1.0000]], device='cuda:0')
In [ ]:
plt.imshow(output.permute(1,2,0).detach().cpu())
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-79-0f48977eb536> in <cell line: 1>()
----> 1 plt.imshow(output.permute(1,2,0).detach().cpu())

/usr/local/lib/python3.10/dist-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)
   2693         interpolation_stage=None, filternorm=True, filterrad=4.0,
   2694         resample=None, url=None, data=None, **kwargs):
-> 2695     __ret = gca().imshow(
   2696         X, cmap=cmap, norm=norm, aspect=aspect,
   2697         interpolation=interpolation, alpha=alpha, vmin=vmin,

/usr/local/lib/python3.10/dist-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1440     def inner(ax, *args, data=None, **kwargs):
   1441         if data is None:
-> 1442             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1443 
   1444         bound = new_sig.bind(ax, *args, **kwargs)

/usr/local/lib/python3.10/dist-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
   5663                               **kwargs)
   5664 
-> 5665         im.set_data(X)
   5666         im.set_alpha(alpha)
   5667         if im.get_clip_path() is None:

/usr/local/lib/python3.10/dist-packages/matplotlib/image.py in set_data(self, A)
    708         if not (self._A.ndim == 2
    709                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
--> 710             raise TypeError("Invalid shape {} for image data"
    711                             .format(self._A.shape))
    712 

TypeError: Invalid shape (64, 3, 64) for image data
In [ ]:
params = hypernet(context).mean(dim=0)

        flat_weights,unravel_fn= ravel_pytree(dict(targetnet.named_parameters()))

        parameter_dictionary = unravel_fn(params)

        outputs = torch.func.functional_call(targetnet, parameter_dictionary, X)
        print(output.shape)
        outputs = outputs.reshape(num_channels, height, width)
        outputs = outputs.permute(1, 2, 0)
In [ ]:
plot_reconstructed_and_original_image(celeba_dataset[celebno],
                                      celeb_hp, celeb_targetnet, img0_X, test_context_data, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])

Neural Process¶

In [ ]:
import torch
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.distributions as dist

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retina display
%config InlineBackend.figure_format = 'retina'
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [ ]:
import os

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

root = "/content/drive/MyDrive/Dataset/CelebA_Dataset"
celeba_dataset = []

for filename in os.listdir(root):
    if filename.endswith('.jpg'):
        image_path = os.path.join(root, filename)
        image = datasets.folder.default_loader(image_path)
        image = transform(image)
        celeba_dataset.append(image)
In [ ]:
len(celeba_dataset)
Out[ ]:
761
In [ ]:
img0 = transform(datasets.folder.default_loader("/content/drive/MyDrive/Dataset/CelebA_Dataset/000001.jpg"))
print(img0.shape)
plt.imshow(img0.permute(1,2,0))
torch.Size([3, 64, 64])
Out[ ]:
<matplotlib.image.AxesImage at 0x7c8a54cd0f10>
In [ ]:
from sklearn import preprocessing

def create_scaled_cmap(img, rt = False):
    """
    Creates a scaled image and a scaled colormap
    """
    img= img
    num_channels, height, width = img.shape

    # Create a 2D grid of (x,y) coordinates
    x_coords = torch.arange(width).repeat(height, 1)
    y_coords = torch.arange(height).repeat(width, 1).t()
    x_coords = x_coords.reshape(-1)
    y_coords = y_coords.reshape(-1)

    X = torch.stack([x_coords, y_coords], dim=1).float().to(device)
    # Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    if rt == True:
      Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    else:
      Y = img.reshape(-1, num_channels).float().to(device)

    scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(X.cpu())
    scaled_X = torch.tensor(scaler_X.transform(X.cpu())).to(device).float()

    return scaled_X, Y, scaler_X
In [ ]:
img0_X_scaled, img0_Y, scaler_X = create_scaled_cmap(img0)
img0_X_scaled.shape, img0_Y.shape
Out[ ]:
(torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
# This neural process model will work for both encoder and decoder
# Smaller model
s = 64

class NN(nn.Module):
    def _init_siren(self, activation_scale):
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/activation_scale,
                                        np.sqrt(6/self.fc2.in_features)/activation_scale)

    def __init__(self, inp_dim = 5, activation=torch.sin, n_out=3, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(inp_dim, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out) #gray scale image (1) or RGB (3)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        # x = self.activation(self.activation_scale*self.fc4(x))
        return self.fc5(x)
In [ ]:
img0_X_scaled.shape, img0_Y.shape
Out[ ]:
(torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
torch.manual_seed(40)
img0_np_data = []
datano = 10
context_percent_start = 10
context_percent_end = 50

for i in range(datano):
    contp = context_percent_start + (i/datano)*(context_percent_end - context_percent_start)
    sh_indexi = torch.randperm(img0_X_scaled.shape[0])
    cont_img0_X_shi = img0_X_scaled[sh_indexi][0:int(len(img0_X_scaled)*contp/100)]
    cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X_scaled)*contp/100)]
    context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)

    train_img0_X_shi = img0_X_scaled #[sh_indexi]
    train_img0_Y_shi = img0_Y #[sh_indexi]

    datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
    img0_np_data.append(datas)
In [ ]:
len(img0_np_data), img0_np_data[0][0].shape, img0_np_data[0][1].shape, img0_np_data[0][2].shape
Out[ ]:
(10, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))

Training the dataset

In [ ]:
K = 500
encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder2 =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

encoder_output = encoder2(img0_np_data[0][0])
enocoded_rep = encoder_output.mean(dim=0).repeat(img0_np_data[0][1].shape[0], 1)
print(img0_np_data[0][0].shape, encoder_output.shape, enocoded_rep.shape)

train_rep = torch.concat((img0_np_data[0][1], enocoded_rep), dim = 1)
print(train_rep.shape)
outputs = decoder2(train_rep)
print(outputs.shape)

def softplus(std, beta = 1, threshold = 20):
  return (1/beta)*(torch.log(1 + torch.exp(std)))

print(type(softplus(outputs[:,3:])))

def normal_loss(mean, log_sigma, actual_val):
  # std = softplus(std)
  sigma = 0.1 + 0.9*softplus(log_sigma)
  # type(std) # sigma of sigma*(1/2)
  return -dist.Normal(mean, sigma).log_prob(actual_val).mean()

loss = normal_loss(outputs[:,:3], outputs[:,3:],  img0_np_data[0][2])
loss
torch.Size([409, 5]) torch.Size([409, 500]) torch.Size([4096, 500])
torch.Size([4096, 502])
torch.Size([4096, 6])
<class 'torch.Tensor'>
Out[ ]:
tensor(0.8217, device='cuda:0', grad_fn=<NegBackward0>)
In [ ]:
def softplus(std, beta = 1, threshold = 20):
  return (1/beta)*(torch.log(1 + torch.exp(std)))

def normal_loss(mean, log_sigma, actual_val):
  sigma = 0.1 + 0.9*softplus(log_sigma)
  return -dist.Normal(mean, sigma).log_prob(actual_val).mean()

def train_np(encoder, decoder, np_data, lr, epochs, verbose=True):
    """
    net: torch.nn.Module
    lr: float
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, 3)
    """

    # criterion = nn.MSELoss()
    optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=lr)
    datano = len(np_data)
    total_loss = []

    for epoch in range(epochs):
        running_loss = 0.0
        for i in range(len(np_data)):
            context_data, train_dog_X_shi, train_dog_Y_shi = np_data[i]

            optimizer.zero_grad()
            encoder_output = encoder(context_data)
            enocoded_rep = encoder_output.mean(dim=0).repeat(train_dog_X_shi.shape[0], 1)

            train_rep = torch.concat((train_dog_X_shi, enocoded_rep), dim = 1)
            outputs = decoder(train_rep)

            # loss = criterion(outputs[:,:3], train_dog_Y_shi)
            loss = normal_loss(outputs[:,:3], outputs[:,3:], train_dog_Y_shi)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # if verbose and epoch % 1 == 0:
            #   # print(encoder.state_dict()["fc1.weight"][0:2, 0:2], decoder.state_dict()["fc1.weight"][0:2, 0:2])
            #   print(f"Epoch {epoch} loss: {loss.item():.6f}")

        total_loss.append(running_loss)

        if verbose and epoch % 50 == 0:
        # if verbose and epoch % 1 == 0:
            # print(encoder.state_dict()["fc1.weight"][0:2, 0:2], decoder.state_dict()["fc1.weight"][0:2, 0:2])
            print(f"Epoch {epoch} loss: {running_loss/datano:.6f}")

    return total_loss
In [ ]:
import time

torch.manual_seed(0)
# can experiment with k
K = 500
encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder2 =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

start_time = time.time()

loss_list = train_np(encoder2, decoder2, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 1.337767
Epoch 50 loss: 0.094281
Epoch 100 loss: 0.091550
Epoch 150 loss: 0.085551
Epoch 200 loss: 0.078786
Epoch 250 loss: 0.075036
Epoch 300 loss: 0.068512
Epoch 350 loss: 0.057991
Epoch 400 loss: 0.032758
Epoch 450 loss: -0.116420
Epoch 500 loss: -0.232012
Epoch 550 loss: -0.305583
Epoch 600 loss: -0.349394
Epoch 650 loss: -0.386280
Epoch 700 loss: -0.409056
Epoch 750 loss: -0.445883
Epoch 800 loss: -0.460155
Epoch 850 loss: -0.474480
Epoch 900 loss: -0.505245
Epoch 950 loss: -0.487306
Epoch 1000 loss: -0.353294
Epoch 1050 loss: -0.533210
Epoch 1100 loss: -0.557027
Epoch 1150 loss: -0.426544
Epoch 1200 loss: -0.576683
Epoch 1250 loss: -0.502492
Epoch 1300 loss: -0.594420
Epoch 1350 loss: -0.607830
Epoch 1400 loss: -0.585463
Epoch 1450 loss: -0.637908
Epoch 1500 loss: -0.650928
Epoch 1550 loss: -0.464352
Epoch 1600 loss: -0.671424
Epoch 1650 loss: -0.626369
Epoch 1700 loss: -0.683977
Epoch 1750 loss: -0.699354
Epoch 1800 loss: -0.689624
Epoch 1850 loss: -0.718512
Epoch 1900 loss: -0.661768
Epoch 1950 loss: -0.701058
Epoch 2000 loss: -0.762148
Epoch 2050 loss: -0.741034
Epoch 2100 loss: -0.774478
Epoch 2150 loss: -0.762232
Epoch 2200 loss: -0.792295
Epoch 2250 loss: -0.753801
Epoch 2300 loss: -0.802900
Epoch 2350 loss: -0.579035
Epoch 2400 loss: -0.814304
Epoch 2450 loss: -0.759662
Epoch 2500 loss: -0.820780
Epoch 2550 loss: -0.829672
Epoch 2600 loss: -0.690268
Epoch 2650 loss: -0.836858
Epoch 2700 loss: -0.813810
Epoch 2750 loss: -0.848221
Epoch 2800 loss: -0.854806
Epoch 2850 loss: -0.763804
Epoch 2900 loss: -0.863535
Epoch 2950 loss: -0.871811
Epoch 3000 loss: -0.555845
Epoch 3050 loss: -0.878646
Epoch 3100 loss: -0.885052
Epoch 3150 loss: -0.739801
Epoch 3200 loss: -0.890640
Epoch 3250 loss: -0.889438
Epoch 3300 loss: -0.900909
Epoch 3350 loss: -0.782015
Epoch 3400 loss: -0.909211
Epoch 3450 loss: -0.911320
Epoch 3500 loss: -0.912907
Epoch 3550 loss: -0.922259
Epoch 3600 loss: -0.760309
Epoch 3650 loss: -0.929416
Epoch 3700 loss: -0.868978
Epoch 3750 loss: -0.927377
Epoch 3800 loss: -0.938281
Epoch 3850 loss: -0.939244
Epoch 3900 loss: -0.930846
Epoch 3950 loss: -0.939761
Epoch 4000 loss: -0.952568
Epoch 4050 loss: -0.924327
Epoch 4100 loss: -0.959611
Epoch 4150 loss: -0.962178
Epoch 4200 loss: -0.960338
Epoch 4250 loss: -0.967879
Epoch 4300 loss: -0.972552
Epoch 4350 loss: -0.968258
Epoch 4400 loss: -0.970368
Epoch 4450 loss: -0.931210
Epoch 4500 loss: -0.979965
Epoch 4550 loss: -0.755416
Epoch 4600 loss: -0.985051
Epoch 4650 loss: -0.989418
Epoch 4700 loss: -0.881498
Epoch 4750 loss: -0.993088
Epoch 4800 loss: -0.976235
Epoch 4850 loss: -0.996279
Epoch 4900 loss: -0.999940
Epoch 4950 loss: -0.921534
Epoch 5000 loss: -1.002399
Epoch 5050 loss: -1.006093
Epoch 5100 loss: -0.990569
Epoch 5150 loss: -1.008689
Epoch 5200 loss: -1.000557
Epoch 5250 loss: -0.952158
Epoch 5300 loss: -1.013231
Epoch 5350 loss: -0.943541
Epoch 5400 loss: -1.013609
Epoch 5450 loss: -1.007544
Epoch 5500 loss: -1.012452
Epoch 5550 loss: -1.018967
Epoch 5600 loss: -1.024164
Epoch 5650 loss: -0.907021
Epoch 5700 loss: -1.026700
Epoch 5750 loss: -1.028270
Epoch 5800 loss: -0.917459
Epoch 5850 loss: -1.030782
Epoch 5900 loss: -1.032855
Epoch 5950 loss: -0.976612
Epoch 6000 loss: -1.031469
Epoch 6050 loss: -1.032512
Epoch 6100 loss: -1.011055
Epoch 6150 loss: -1.040253
Epoch 6200 loss: -1.044241
Epoch 6250 loss: -0.965589
Epoch 6300 loss: -1.043470
Epoch 6350 loss: -1.048069
Epoch 6400 loss: -0.946986
Epoch 6450 loss: -1.050458
Epoch 6500 loss: -1.050313
Epoch 6550 loss: -0.924406
Epoch 6600 loss: -1.053265
Epoch 6650 loss: -1.056274
Epoch 6700 loss: -1.010929
Epoch 6750 loss: -1.058735
Epoch 6800 loss: -1.056873
Epoch 6850 loss: -0.995843
Epoch 6900 loss: -1.063446
Epoch 6950 loss: -1.004386
Epoch 7000 loss: -1.064945
Epoch 7050 loss: -1.051375
Epoch 7100 loss: -1.065526
Epoch 7150 loss: -1.067781
Epoch 7200 loss: -1.004980
Epoch 7250 loss: -1.072607
Epoch 7300 loss: -1.024302
Epoch 7350 loss: -1.074019
Epoch 7400 loss: -1.072818
Epoch 7450 loss: -1.053316
Epoch 7500 loss: -1.065668
Epoch 7550 loss: -0.961100
Epoch 7600 loss: -1.073951
Epoch 7650 loss: -1.082017
Epoch 7700 loss: -1.001805
Epoch 7750 loss: -1.077227
Epoch 7800 loss: -1.083825
Epoch 7850 loss: -1.081036
Epoch 7900 loss: -1.087032
Epoch 7950 loss: -1.051721
Training time: 316.47 seconds
In [ ]:
# loss_list2 = train_np(encoder2, decoder2, img0_np_data, lr=1e-3, epochs=4000, verbose=True)
In [ ]:
torch.save(encoder2.state_dict(), "img0_encoder2_hypernet.pt")
torch.save(decoder2.state_dict(), "img0_decoder2_hypernet.pt")
In [ ]:
plt.plot(loss_list)
plt.xlabel("Iterations")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
def plot_np_image(encoder, decoder, datano, key, context_percent, title = ""):
    torch.manual_seed(key)
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[datano], rt = True)
    #now, the context has also changed
    sh_index = torch.randperm(scaled_X.shape[0])

    # if int(len(scaled_X)*context_percent/100) > 6000:
    #     print("context size is too big")
    #     return None

    cont_img_X = scaled_X[sh_index][0:int(len(scaled_X)*context_percent/100)]
    cont_img_Y = Y[sh_index][0:int(len(scaled_X)*context_percent/100)]
    context = torch.cat((cont_img_X, cont_img_Y), dim = 1)

    encoder_output = encoder(context)
    encoded_rep = encoder_output.mean(dim=0).repeat(scaled_X.shape[0], 1)

    train_rep = torch.cat((scaled_X, encoded_rep), dim = 1)
    output = decoder(train_rep)
    var = 0.1 + 0.9*softplus(output[:,3:]**2) #output[:,3:]**2
    output = output[:,:3]

    num_channels, height, width = celeba_dataset[datano].shape
    output = output.reshape(num_channels, height, width)
    output = output.permute(1, 2, 0)

    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 1])

    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])

    ax0.imshow(celeba_dataset[datano].cpu().permute(1, 2, 0))
    ax0.set_title("Original Image")
    ax0.axis("off")

    ax1.imshow(output.detach().cpu())
    ax1.set_title("Reconstructed Image")
    ax1.axis("off")

    actual_img = torch.ones(celeba_dataset[datano].shape).permute(1, 2, 0)
    # actual_img = (celeba_dataset[datano][0].cpu()*0).permute(1, 2, 0)
    cont_img_X_unscaled = scaler_X.inverse_transform(cont_img_X.cpu())
    for i,x in enumerate(cont_img_X_unscaled):
        actual_img[int(x[1]+0.5), int(x[0]+0.5)] = torch.tensor(cont_img_Y[i].cpu().detach().numpy())

    ax2.imshow(actual_img)
    # ax2.scatter(cont_img_X[:, 0].detach().cpu(), cont_img_X[:, 1].detach().cpu(), s=10, c='r')
    ax2.set_title("Context Points")
    ax2.axis("off")

    var = var.reshape(num_channels, height, width)
    var = var.permute(1, 2, 0)
    ax3.imshow(var.detach().cpu())
    ax3.set_title("Variance")
    ax3.axis("off")

    fig.suptitle(title, y=0.9)
    plt.tight_layout()
In [ ]:
plot_np_image(encoder2, decoder2, datano = 0, key = 41, context_percent = 20, title = "20% context K = 500")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
import time

torch.manual_seed(0)
# can experiment with k
K = 200
encoder3 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder3 =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

start_time = time.time()

loss_list3 = train_np(encoder3, decoder3, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.985983
Epoch 50 loss: 0.094548
Epoch 100 loss: 0.094396
Epoch 150 loss: 0.094360
Epoch 200 loss: 0.094433
Epoch 250 loss: 0.086920
Epoch 300 loss: 0.076237
Epoch 350 loss: 0.068208
Epoch 400 loss: 0.061454
Epoch 450 loss: 0.051312
Epoch 500 loss: 0.002067
Epoch 550 loss: -0.171355
Epoch 600 loss: -0.258447
Epoch 650 loss: -0.351243
Epoch 700 loss: -0.318880
Epoch 750 loss: -0.430725
Epoch 800 loss: -0.443474
Epoch 850 loss: -0.442860
Epoch 900 loss: -0.471925
Epoch 950 loss: -0.354867
Epoch 1000 loss: -0.487307
Epoch 1050 loss: -0.485608
Epoch 1100 loss: -0.503016
Epoch 1150 loss: -0.463705
Epoch 1200 loss: -0.516598
Epoch 1250 loss: -0.413420
Epoch 1300 loss: -0.511021
Epoch 1350 loss: -0.534636
Epoch 1400 loss: -0.533560
Epoch 1450 loss: -0.550606
Epoch 1500 loss: -0.501212
Epoch 1550 loss: -0.564356
Epoch 1600 loss: -0.565869
Epoch 1650 loss: -0.566113
Epoch 1700 loss: -0.583977
Epoch 1750 loss: -0.594680
Epoch 1800 loss: -0.601825
Epoch 1850 loss: -0.603967
Epoch 1900 loss: -0.630418
Epoch 1950 loss: -0.629093
Epoch 2000 loss: -0.642894
Epoch 2050 loss: -0.628538
Epoch 2100 loss: -0.660608
Epoch 2150 loss: -0.671160
Epoch 2200 loss: -0.680648
Epoch 2250 loss: -0.657872
Epoch 2300 loss: -0.689692
Epoch 2350 loss: -0.688498
Epoch 2400 loss: -0.700271
Epoch 2450 loss: -0.712134
Epoch 2500 loss: -0.718355
Epoch 2550 loss: -0.538734
Epoch 2600 loss: -0.725233
Epoch 2650 loss: -0.732642
Epoch 2700 loss: -0.717478
Epoch 2750 loss: -0.742109
Epoch 2800 loss: -0.751388
Epoch 2850 loss: -0.749122
Epoch 2900 loss: -0.760358
Epoch 2950 loss: -0.767813
Epoch 3000 loss: -0.562763
Epoch 3050 loss: -0.774568
Epoch 3100 loss: -0.715387
Epoch 3150 loss: -0.785606
Epoch 3200 loss: -0.517251
Epoch 3250 loss: -0.788772
Epoch 3300 loss: -0.793651
Epoch 3350 loss: -0.796951
Epoch 3400 loss: -0.798574
Epoch 3450 loss: -0.807807
Epoch 3500 loss: -0.807657
Epoch 3550 loss: -0.814677
Epoch 3600 loss: -0.698318
Epoch 3650 loss: -0.820988
Epoch 3700 loss: -0.786445
Epoch 3750 loss: -0.827814
Epoch 3800 loss: -0.832542
Epoch 3850 loss: -0.604163
Epoch 3900 loss: -0.835482
Epoch 3950 loss: -0.836838
Epoch 4000 loss: -0.842324
Epoch 4050 loss: -0.741636
Epoch 4100 loss: -0.848315
Epoch 4150 loss: -0.809173
Epoch 4200 loss: -0.851881
Epoch 4250 loss: -0.854729
Epoch 4300 loss: -0.839670
Epoch 4350 loss: -0.856656
Epoch 4400 loss: -0.861078
Epoch 4450 loss: -0.862912
Epoch 4500 loss: -0.866721
Epoch 4550 loss: -0.860558
Epoch 4600 loss: -0.866106
Epoch 4650 loss: -0.840890
Epoch 4700 loss: -0.871417
Epoch 4750 loss: -0.877970
Epoch 4800 loss: -0.685538
Epoch 4850 loss: -0.881884
Epoch 4900 loss: -0.885006
Epoch 4950 loss: -0.852265
Epoch 5000 loss: -0.888602
Epoch 5050 loss: -0.887027
Epoch 5100 loss: -0.886767
Epoch 5150 loss: -0.886860
Epoch 5200 loss: -0.894198
Epoch 5250 loss: -0.892315
Epoch 5300 loss: -0.898365
Epoch 5350 loss: -0.742378
Epoch 5400 loss: -0.901049
Epoch 5450 loss: -0.901360
Epoch 5500 loss: -0.902117
Epoch 5550 loss: -0.899075
Epoch 5600 loss: -0.907264
Epoch 5650 loss: -0.902203
Epoch 5700 loss: -0.906515
Epoch 5750 loss: -0.904309
Epoch 5800 loss: -0.914406
Epoch 5850 loss: -0.914590
Epoch 5900 loss: -0.911589
Epoch 5950 loss: -0.917603
Epoch 6000 loss: -0.883401
Epoch 6050 loss: -0.920532
Epoch 6100 loss: -0.912924
Epoch 6150 loss: -0.922672
Epoch 6200 loss: -0.924746
Epoch 6250 loss: -0.883664
Epoch 6300 loss: -0.924836
Epoch 6350 loss: -0.925393
Epoch 6400 loss: -0.926199
Epoch 6450 loss: -0.875144
Epoch 6500 loss: -0.930490
Epoch 6550 loss: -0.924770
Epoch 6600 loss: -0.855461
Epoch 6650 loss: -0.934010
Epoch 6700 loss: -0.907176
Epoch 6750 loss: -0.935576
Epoch 6800 loss: -0.932074
Epoch 6850 loss: -0.937173
Epoch 6900 loss: -0.924708
Epoch 6950 loss: -0.939628
Epoch 7000 loss: -0.888961
Epoch 7050 loss: -0.940577
Epoch 7100 loss: -0.938378
Epoch 7150 loss: -0.941364
Epoch 7200 loss: -0.915255
Epoch 7250 loss: -0.941812
Epoch 7300 loss: -0.944872
Epoch 7350 loss: -0.917466
Epoch 7400 loss: -0.946301
Epoch 7450 loss: -0.812070
Epoch 7500 loss: -0.947961
Epoch 7550 loss: -0.949372
Epoch 7600 loss: -0.750076
Epoch 7650 loss: -0.949743
Epoch 7700 loss: -0.951500
Epoch 7750 loss: -0.931081
Epoch 7800 loss: -0.937685
Epoch 7850 loss: -0.953335
Epoch 7900 loss: -0.952670
Epoch 7950 loss: -0.954446
Training time: 330.42 seconds
In [ ]:
plt.plot(loss_list3)
plt.xlabel("Iterations")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
plot_np_image(encoder3, decoder3, datano = 0, key = 41, context_percent = 20, title = "20% context K = 200")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Higher variance for lesser K

In [ ]:
import time

torch.manual_seed(0)
# can experiment with k
K = 1000
encoder4 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder4 =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

start_time = time.time()

loss_list4 = train_np(encoder4, decoder4, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 3.247491
Epoch 50 loss: 0.095072
Epoch 100 loss: 0.094696
Epoch 150 loss: 0.095501
Epoch 200 loss: 0.092775
Epoch 250 loss: 0.105793
Epoch 300 loss: 0.078307
Epoch 350 loss: 0.960751
Epoch 400 loss: 0.472558
Epoch 450 loss: 0.091825
Epoch 500 loss: 0.088676
Epoch 550 loss: 0.087290
Epoch 600 loss: 0.085511
Epoch 650 loss: 0.516163
Epoch 700 loss: 0.225607
Epoch 750 loss: 0.090968
Epoch 800 loss: 0.100098
Epoch 850 loss: 0.088283
Epoch 900 loss: 0.080500
Epoch 950 loss: 0.080115
Epoch 1000 loss: 0.364954
Epoch 1050 loss: 0.084921
Epoch 1100 loss: 0.078579
Epoch 1150 loss: 0.076788
Epoch 1200 loss: 0.081680
Epoch 1250 loss: 0.078172
Epoch 1300 loss: 0.069085
Epoch 1350 loss: 0.064416
Epoch 1400 loss: 0.296190
Epoch 1450 loss: 0.061623
Epoch 1500 loss: 0.050863
Epoch 1550 loss: 0.006544
Epoch 1600 loss: 0.018891
Epoch 1650 loss: -0.085501
Epoch 1700 loss: -0.113087
Epoch 1750 loss: -0.225822
Epoch 1800 loss: -0.237431
Epoch 1850 loss: -0.276261
Epoch 1900 loss: -0.392886
Epoch 1950 loss: -0.414422
Epoch 2000 loss: -0.444768
Epoch 2050 loss: -0.464655
Epoch 2100 loss: -0.506322
Epoch 2150 loss: -0.525111
Epoch 2200 loss: -0.593114
Epoch 2250 loss: -0.599869
Epoch 2300 loss: -0.608259
Epoch 2350 loss: -0.640627
Epoch 2400 loss: -0.578990
Epoch 2450 loss: -0.684827
Epoch 2500 loss: -0.749774
Epoch 2550 loss: -0.768743
Epoch 2600 loss: -0.657960
Epoch 2650 loss: -0.783812
Epoch 2700 loss: -0.628939
Epoch 2750 loss: -0.817731
Epoch 2800 loss: -0.764647
Epoch 2850 loss: -0.847248
Epoch 2900 loss: -0.866466
Epoch 2950 loss: -0.645291
Epoch 3000 loss: -0.770316
Epoch 3050 loss: -0.874695
Epoch 3100 loss: -0.901799
Epoch 3150 loss: -0.864566
Epoch 3200 loss: -0.919835
Epoch 3250 loss: -0.858242
Epoch 3300 loss: -0.905720
Epoch 3350 loss: -0.878959
Epoch 3400 loss: -0.859680
Epoch 3450 loss: -0.489101
Epoch 3500 loss: -0.952121
Epoch 3550 loss: -0.941618
Epoch 3600 loss: -0.894202
Epoch 3650 loss: -0.915352
Epoch 3700 loss: -0.953912
Epoch 3750 loss: -0.722261
Epoch 3800 loss: -0.959066
Epoch 3850 loss: -0.724264
Epoch 3900 loss: -0.996326
Epoch 3950 loss: -0.999575
Epoch 4000 loss: -0.811981
Epoch 4050 loss: -0.868557
Epoch 4100 loss: -0.856301
Epoch 4150 loss: -0.871763
Epoch 4200 loss: -0.902275
Epoch 4250 loss: -0.870940
Epoch 4300 loss: -1.014184
Epoch 4350 loss: -0.979585
Epoch 4400 loss: -1.015361
Epoch 4450 loss: -0.978255
Epoch 4500 loss: -0.834727
Epoch 4550 loss: -0.914563
Epoch 4600 loss: -1.051477
Epoch 4650 loss: -1.026398
Epoch 4700 loss: -1.011141
Epoch 4750 loss: -1.044407
Epoch 4800 loss: -0.998591
Epoch 4850 loss: -1.062151
Epoch 4900 loss: -0.842098
Epoch 4950 loss: -1.062960
Epoch 5000 loss: -0.959533
Epoch 5050 loss: -1.064741
Epoch 5100 loss: -0.958881
Epoch 5150 loss: -1.076099
Epoch 5200 loss: -0.927501
Epoch 5250 loss: -1.054025
Epoch 5300 loss: -0.935108
Epoch 5350 loss: -1.082859
Epoch 5400 loss: -0.941729
Epoch 5450 loss: -1.084702
Epoch 5500 loss: -0.919471
Epoch 5550 loss: -1.005004
Epoch 5600 loss: -1.025450
Epoch 5650 loss: -1.093088
Epoch 5700 loss: -0.580471
Epoch 5750 loss: -1.090108
Epoch 5800 loss: -1.047101
Epoch 5850 loss: -1.105287
Epoch 5900 loss: -1.054601
Epoch 5950 loss: -1.036265
Epoch 6000 loss: -0.904680
Epoch 6050 loss: -0.792468
Epoch 6100 loss: -1.111913
Epoch 6150 loss: -0.964587
Epoch 6200 loss: -1.117073
Epoch 6250 loss: -1.122940
Epoch 6300 loss: -0.303059
Epoch 6350 loss: -1.100989
Epoch 6400 loss: -1.123604
Epoch 6450 loss: -1.111797
Epoch 6500 loss: -1.066253
Epoch 6550 loss: -1.021821
Epoch 6600 loss: -1.128251
Epoch 6650 loss: -1.057411
Epoch 6700 loss: -1.126673
Epoch 6750 loss: -1.105785
Epoch 6800 loss: -0.909334
Epoch 6850 loss: -0.948805
Epoch 6900 loss: -1.133896
Epoch 6950 loss: -1.073505
Epoch 7000 loss: -1.136847
Epoch 7050 loss: -1.073588
Epoch 7100 loss: -1.009102
Epoch 7150 loss: -1.140496
Epoch 7200 loss: -1.034243
Epoch 7250 loss: -1.141284
Epoch 7300 loss: -1.034567
Epoch 7350 loss: -0.994248
Epoch 7400 loss: -1.137774
Epoch 7450 loss: -1.035262
Epoch 7500 loss: -1.130714
Epoch 7550 loss: -1.141972
Epoch 7600 loss: -1.146804
Epoch 7650 loss: -1.139201
Epoch 7700 loss: -1.146221
Epoch 7750 loss: -1.093645
Epoch 7800 loss: -0.959759
Epoch 7850 loss: -1.150753
Epoch 7900 loss: -0.820156
Epoch 7950 loss: -1.154293
Training time: 317.62 seconds
In [ ]:
plt.plot(loss_list4)
plt.xlabel("Iterations")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
plot_np_image(encoder4, decoder4, datano = 0, key = 41, context_percent = 20, title = "20% context K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Training for multiple images¶

10 dp, 8000 epochs = 4 min, 200 dp, 3500 epochs = 35 min

In [ ]:
200*3500*4/(10*8000)
Out[ ]:
35.0
In [ ]:
torch.manual_seed(42)
celeb_np_data = []
celebno = 50
datano = 4 # 100
context_percent_start = 10
context_percent_end = 50

for i in range(celebno):
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[i])
    for j in range(datano):
      context_percent = context_percent_start + (i/datano)*(context_percent_end - context_percent_start)
      sh_index = torch.randperm(scaled_X.shape[0])

      cont_img_Xi = scaled_X[sh_index][0:int(len(scaled_X)*context_percent/100)]
      cont_img_Yi = Y[sh_index][0:int(len(scaled_X)*context_percent/100)]
      context_data = torch.cat((cont_img_Xi, cont_img_Yi), dim = 1)

      train_img_Xi = scaled_X[sh_index]
      train_img_Yi = Y[sh_index]

      datas = [context_data, train_img_Xi, train_img_Yi]
      celeb_np_data.append(datas)
In [ ]:
len(celeb_np_data) ,celeb_np_data[0][0].shape, celeb_np_data[0][1].shape, celeb_np_data[0][2].shape
Out[ ]:
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
import time

torch.manual_seed(0)
K = 1000 # can change based on experiment
celeb_encoder = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
celeb_decoder =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

start_time = time.time()
celeb_loss_list = train_np(celeb_encoder, celeb_decoder, celeb_np_data, lr=1e-3, epochs= 3500, verbose=True)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.561838
Epoch 50 loss: 0.062497
Epoch 100 loss: 0.049615
Epoch 150 loss: -0.013535
Epoch 200 loss: -0.008642
Epoch 250 loss: -0.063802
Epoch 300 loss: -0.103127
Epoch 350 loss: -0.132578
Epoch 400 loss: -0.150624
Epoch 450 loss: -0.161050
Epoch 500 loss: -0.163209
Epoch 550 loss: -0.184284
Epoch 600 loss: -0.186982
Epoch 650 loss: -0.179719
Epoch 700 loss: -0.203929
Epoch 750 loss: -0.195653
Epoch 800 loss: -0.209538
Epoch 850 loss: -0.202713
Epoch 900 loss: -0.212755
Epoch 950 loss: -0.203270
Epoch 1000 loss: -0.210044
Epoch 1050 loss: -0.209445
Epoch 1100 loss: -0.223716
Epoch 1150 loss: -0.230180
Epoch 1200 loss: -0.219785
Epoch 1250 loss: -0.221717
Epoch 1300 loss: -0.229299
Epoch 1350 loss: -0.354515
Epoch 1400 loss: -0.472582
Epoch 1450 loss: -0.508528
Epoch 1500 loss: -0.522056
Epoch 1550 loss: -0.524381
Epoch 1600 loss: -0.533695
Epoch 1650 loss: -0.543165
Epoch 1700 loss: -0.532182
Epoch 1750 loss: -0.565511
Epoch 1800 loss: -0.564247
Epoch 1850 loss: -0.541689
Epoch 1900 loss: -0.565920
Epoch 1950 loss: -0.556600
Epoch 2000 loss: -0.563846
Epoch 2050 loss: -0.586694
Epoch 2100 loss: -0.570260
Epoch 2150 loss: -0.581060
Epoch 2200 loss: -0.584690
Epoch 2250 loss: -0.589820
Epoch 2300 loss: -0.581520
Epoch 2350 loss: -0.588501
Epoch 2400 loss: -0.588752
Epoch 2450 loss: -0.604451
Epoch 2500 loss: -0.589753
Epoch 2550 loss: -0.565134
Epoch 2600 loss: -0.581715
Epoch 2650 loss: -0.610918
Epoch 2700 loss: -0.577788
Epoch 2750 loss: -0.594610
Epoch 2800 loss: -0.606971
Epoch 2850 loss: -0.606520
Epoch 2900 loss: -0.607718
Epoch 2950 loss: -0.594243
Epoch 3000 loss: -0.576128
Epoch 3050 loss: -0.603738
Epoch 3100 loss: -0.586131
Epoch 3150 loss: -0.601565
Epoch 3200 loss: -0.614933
Epoch 3250 loss: -0.616316
Epoch 3300 loss: -0.609108
Epoch 3350 loss: -0.624380
Epoch 3400 loss: -0.611612
Epoch 3450 loss: -0.622663
Training time: 2831.85 seconds
In [ ]:
torch.save(celeb_encoder.state_dict(), "celeb_encoder1_hypernet.pt")
torch.save(celeb_decoder.state_dict(), "celeb_decoder1_hypernet.pt")
In [ ]:
plt.plot(celeb_loss_list)
plt.xlabel("Iterations")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 10, title = "10% context")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 50, title = "50% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 1, title = "1% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
plot_np_image(celeb_encoder, celeb_decoder, datano = 5, key = 41, context_percent = 100, title = "100% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
torch.manual_seed(42)
celeb_np_data = []
celebno = 50
datano = 4 # 100
context_percent_start = 10
context_percent_end = 50

for i in range(celebno):
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[i])
    for j in range(datano):
      context_percent = context_percent_start + (i/datano)*(context_percent_end - context_percent_start)
      sh_index = torch.randperm(scaled_X.shape[0])

      cont_img_Xi = scaled_X[sh_index][0:int(len(scaled_X)*context_percent/100)]
      cont_img_Yi = Y[sh_index][0:int(len(scaled_X)*context_percent/100)]
      context_data = torch.cat((cont_img_Xi, cont_img_Yi), dim = 1)

      train_img_Xi = scaled_X[sh_index]
      train_img_Yi = Y[sh_index]

      datas = [context_data, train_img_Xi, train_img_Yi]
      celeb_np_data.append(datas)
In [ ]:
len(celeb_np_data) ,celeb_np_data[0][0].shape, celeb_np_data[0][1].shape, celeb_np_data[0][2].shape
Out[ ]:
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
In [ ]:
import time

torch.manual_seed(0)
K = 1000 # can change based on experiment
celeb_encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
celeb_decoder2 =  NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)

start_time = time.time()
celeb_loss_list2 = train_np(celeb_encoder2, celeb_decoder2, celeb_np_data, lr=1e-3, epochs= 6000, verbose=True)
end_time = time.time()

print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.479685
Epoch 50 loss: 0.168451
Epoch 100 loss: 0.089300
Epoch 150 loss: 0.095178
Epoch 200 loss: 0.069774
Epoch 250 loss: 0.023266
Epoch 300 loss: -0.022957
Epoch 350 loss: -0.045284
Epoch 400 loss: -0.066094
Epoch 450 loss: -0.083994
Epoch 500 loss: -0.102610
Epoch 550 loss: -0.112196
Epoch 600 loss: -0.118784
Epoch 650 loss: -0.126227
Epoch 700 loss: -0.124259
Epoch 750 loss: -0.124713
Epoch 800 loss: -0.130140
Epoch 850 loss: -0.138738
Epoch 900 loss: -0.143222
Epoch 950 loss: -0.142787
Epoch 1000 loss: -0.149772
Epoch 1050 loss: -0.145987
Epoch 1100 loss: -0.183938
Epoch 1150 loss: -0.297947
Epoch 1200 loss: -0.390763
Epoch 1250 loss: -0.417370
Epoch 1300 loss: -0.438389
Epoch 1350 loss: -0.451033
Epoch 1400 loss: -0.494825
Epoch 1450 loss: -0.488731
Epoch 1500 loss: -0.514693
Epoch 1550 loss: -0.542114
Epoch 1600 loss: -0.530419
Epoch 1650 loss: -0.525035
Epoch 1700 loss: -0.531059
Epoch 1750 loss: -0.543996
Epoch 1800 loss: -0.556629
Epoch 1850 loss: -0.549620
Epoch 1900 loss: -0.563347
Epoch 1950 loss: -0.552585
Epoch 2000 loss: -0.555627
Epoch 2050 loss: -0.569125
Epoch 2100 loss: -0.581341
Epoch 2150 loss: -0.550785
Epoch 2200 loss: -0.577023
Epoch 2250 loss: -0.594627
Epoch 2300 loss: -0.564444
Epoch 2350 loss: -0.584900
Epoch 2400 loss: -0.597912
Epoch 2450 loss: -0.580392
Epoch 2500 loss: -0.586710
Epoch 2550 loss: -0.581483
Epoch 2600 loss: -0.580890
Epoch 2650 loss: -0.604032
Epoch 2700 loss: -0.589725
Epoch 2750 loss: -0.588449
Epoch 2800 loss: -0.588983
Epoch 2850 loss: -0.596554
Epoch 2900 loss: -0.607114
Epoch 2950 loss: -0.596147
Epoch 3000 loss: -0.599351
Epoch 3050 loss: -0.602844
Epoch 3100 loss: -0.606371
Epoch 3150 loss: -0.612399
Epoch 3200 loss: -0.623650
Epoch 3250 loss: -0.598958
Epoch 3300 loss: -0.614483
Epoch 3350 loss: -0.624101
Epoch 3400 loss: -0.621168
Epoch 3450 loss: -0.605707
Epoch 3500 loss: -0.607578
Epoch 3550 loss: -0.626449
Epoch 3600 loss: -0.598677
Epoch 3650 loss: -0.631481
Epoch 3700 loss: -0.631332
Epoch 3750 loss: -0.623939
Epoch 3800 loss: -0.618339
Epoch 3850 loss: -0.623305
Epoch 3900 loss: -0.630795
Epoch 3950 loss: -0.650639
Epoch 4000 loss: -0.631177
Epoch 4050 loss: -0.634612
Epoch 4100 loss: -0.637233
Epoch 4150 loss: -0.638567
Epoch 4200 loss: -0.621981
Epoch 4250 loss: -0.613415
Epoch 4300 loss: -0.647063
Epoch 4350 loss: -0.645125
Epoch 4400 loss: -0.633680
Epoch 4450 loss: -0.614909
Epoch 4500 loss: -0.649519
Epoch 4550 loss: -0.656437
Epoch 4600 loss: -0.647346
Epoch 4650 loss: -0.643765
Epoch 4700 loss: -0.647531
Epoch 4750 loss: -0.647105
Epoch 4800 loss: -0.637083
Epoch 4850 loss: -0.666762
Epoch 4900 loss: -0.662398
Epoch 4950 loss: -0.665553
Epoch 5000 loss: -0.655768
Epoch 5050 loss: -0.649421
Epoch 5100 loss: -0.654450
Epoch 5150 loss: -0.664157
Epoch 5200 loss: -0.657926
Epoch 5250 loss: -0.659550
Epoch 5300 loss: -0.625096
Epoch 5350 loss: -0.659383
Epoch 5400 loss: -0.661412
Epoch 5450 loss: -0.678888
Epoch 5500 loss: -0.663640
Epoch 5550 loss: -0.674701
Epoch 5600 loss: -0.631961
Epoch 5650 loss: -0.684443
Epoch 5700 loss: -0.635810
Epoch 5750 loss: -0.646522
Epoch 5800 loss: -0.658710
Epoch 5850 loss: -0.668171
Epoch 5900 loss: -0.669239
Epoch 5950 loss: -0.660185
Training time: 4817.50 seconds
In [ ]:
torch.save(celeb_encoder2.state_dict(), "celeb_encoder2_hypernet.pt")
torch.save(celeb_decoder2.state_dict(), "celeb_decoder2_hypernet.pt")
In [ ]:
plt.plot(celeb_loss_list2)
plt.xlabel("Iterations")
plt.ylabel("loss")
Out[ ]:
Text(0, 0.5, 'loss')
In [ ]:
plot_np_image(celeb_encoder2, celeb_decoder2, datano = 10, key = 39, context_percent = 50, title = "50% context")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Question 4 [2 marks]¶

Write the Random walk Metropolis Hastings algorithms from scratch.

Take 1000 samples using below given log probs and compare the mean and covariance matrix with hamiltorch’s standard HMC and emcee’s Metropolis Hastings implementation. Use 500 samples as the burn/warm up samples.

Also check the relation between acceptance ratio and the sigma of the proposal distribution in your from scratch implementation. Use the log likelihood function given below.

In [1]:
import torch
import torch.distributions as dist

def log_likelihood(omega):
    mean = torch.tensor([0., 0.])
    stddev = torch.tensor([0.5, 1.])
    return dist.MultivariateNormal(mean, torch.diag(stddev**2)).log_prob(omega).sum()
In [2]:
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

# The grid
x1 = np.linspace(-5, 5, 100)
x2 = np.linspace(-5, 5, 100)
X1, X2 = np.meshgrid(x1, x2)
Z = np.zeros_like(X1)
Z_exp = np.zeros_like(X1)

# Evaluate the function at each point on the grid
for i in range(X1.shape[0]):
    for j in range(X1.shape[1]):
        Z[i, j] = log_likelihood(torch.tensor([X1[i, j], X2[i, j]]))
        Z_exp[i, j] = np.exp(Z[i, j])

# Plot the contours
fig, ax = plt.subplots(figsize=(6, 4))
contour = ax.contour(X1, X2, Z, levels=50, cmap='viridis')
colorbar = plt.colorbar(contour, ax=ax, label='log likelihood')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title('Log likelihood contour')
Out[2]:
Text(0.5, 1.0, 'Log likelihood contour')
In [3]:
# 3-D Plot of the contours
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X1, X2, Z_exp, cmap='viridis')
ax.set_xlabel('x1'); ax.set_ylabel('x2'); ax.set_zlabel('likelihood')
ax.set_title("Probability Density Function")
plt.show()

Metropolis Hastings Algorithm¶

The acceptance ration of MHA is:-

$$ r = \frac{p(\theta^{*})/ J_{t}(\theta^{*}|\theta^{t-1})}{p(\theta^{t-1}) / J_{t}(\theta^{t-1}|\theta^{*})} $$

Where $J_{t}(\theta^{*}|\theta^{t-1})$ is the jump distribution at time $t$.

In [4]:
import torch

def random_walk_metropolis_hastings(log_likelihood, initial_state, num_samples, jump_function):
    samples = [initial_state]
    current_state = initial_state
    accepted_count = 0

    for i in range(num_samples):
        # Propose a new state by adding Gaussian noise
        # proposal = current_state + proposal_stddev * torch.randn(current_state.shape)
        proposal = jump_function(current_state).sample()

        # Calculate the acceptance ratio
        log_prob_current = log_likelihood(current_state)
        log_prob_jump = jump_function(current_state).log_prob(proposal)
        log_prob_proposal = log_likelihood(proposal)
        log_prob_rev_jump = jump_function(proposal).log_prob(current_state)

        acceptance_ratio = torch.exp(log_prob_proposal - log_prob_jump - log_prob_current + log_prob_rev_jump)

        # Accept or reject the proposal
        if torch.rand(1) < acceptance_ratio:
            current_state = proposal
            accepted_count += 1
        else:
            i -= 1

        samples.append(current_state)

    acceptance_rate = accepted_count / num_samples
    return torch.stack(samples[500:]), torch.stack(samples), acceptance_rate
In [9]:
torch.manual_seed(0)

normal_stddev = torch.tensor([[0.1,0],[0,0.1]])#0.1

def gaussian_jump(current_state):
    return dist.MultivariateNormal(current_state, normal_stddev)

initial_state = torch.tensor([-4, 3.5])
jump_function = gaussian_jump
num_samples = 1500

samples, all_samples, accp_rate = random_walk_metropolis_hastings(log_likelihood, initial_state, num_samples, jump_function)
In [10]:
accp_rate, samples.shape, all_samples.shape
Out[10]:
(0.756, torch.Size([1001, 2]), torch.Size([1501, 2]))
In [11]:
def plot_samples(samples, title, lines = False):
    x1 = np.linspace(-5, 5, 100)
    x2 = np.linspace(-5, 5, 100)
    X1, X2 = np.meshgrid(x1, x2)
    Z = np.zeros_like(X1)
    Z_exp = np.zeros_like(X1)

    # Evaluate the function at each point on the grid
    for i in range(X1.shape[0]):
        for j in range(X1.shape[1]):
            Z[i, j] = log_likelihood(torch.tensor([X1[i, j], X2[i, j]]))
            Z_exp[i, j] = np.exp(Z[i, j])

    # Plot the contours
    fig, ax = plt.subplots(figsize=(8, 6))
    if lines:
        ax.plot(samples.numpy()[:, 0], samples.numpy()[:, 1], alpha=0.5, color = "red", label='Samples')
    # else:
    ax.scatter(samples.numpy()[:, 0], samples.numpy()[:, 1], s=2, alpha=0.5, label='Samples')
    contour = ax.contour(X1, X2, Z, levels=50, cmap='viridis')
    colorbar = plt.colorbar(contour, ax=ax, label='log likelihood')
    ax.legend()
    fig.suptitle(title)
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')

    plt.plot
In [12]:
plot_samples(all_samples, f'Random Walk Metropolis Hastings Acceptance rate = {accp_rate:.2f}', lines = True)
In [13]:
mean = torch.mean(samples, dim=0)
cov = torch.tensor(np.cov(samples.T))
samples.shape, mean, cov
Out[13]:
(torch.Size([1001, 2]),
 tensor([0.0460, 0.1631]),
 tensor([[0.2572, 0.0197],
         [0.0197, 0.7183]], dtype=torch.float64))
In [17]:
import hamiltorch
torch.manual_seed(0)

params_init = torch.tensor([-4, 3.5])
samples_hmc = hamiltorch.sample(log_prob_func= log_likelihood, params_init=params_init, num_samples=1500,
                               step_size=0.1, num_steps_per_sample=10)
Sampling (Sampler.HMC; Integrator.IMPLICIT)
Time spent  | Time remain.| Progress             | Samples   | Samples/sec
0d:00:00:09 | 0d:00:00:00 | #################### | 1500/1500 | 158.76       
Acceptance Rate 0.99
In [18]:
samples_hmc = torch.stack(samples_hmc)[500:]
samples_hmc.shape
Out[18]:
torch.Size([1000, 2])
In [19]:
hmc_mean = torch.mean(samples_hmc, dim=0)
hmc_cov = torch.tensor(np.cov(samples_hmc.T))
samples_hmc.shape, hmc_mean, hmc_cov
Out[19]:
(torch.Size([1000, 2]),
 tensor([0.0072, 0.0218]),
 tensor([[0.2324, 0.0075],
         [0.0075, 0.9991]], dtype=torch.float64))
In [20]:
plot_samples(samples_hmc, f'Hamiltorch HMC Samples, Acceptance rate = 1.00', lines = True)

Trying Emcee¶

In [21]:
import emcee

means = np.array([0, 0])
cov = np.array([[0.1, 0.0],
                [0.0, 0.1]])

def log_prob(x, mu, cov):
    diff = x - mu
    return -0.5 * np.dot(diff, np.linalg.solve(cov, diff))

ndim = 2
nwalkers = 5
p0 = np.random.rand(nwalkers, ndim)

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[means, cov])
In [22]:
state = sampler.run_mcmc(p0, 100)
sampler.reset()

result = sampler.run_mcmc(state, 200)
In [23]:
emcee_accprate = np.mean(sampler.acceptance_fraction)
print(
    "Mean acceptance fraction: {0:.3f}".format(
        emcee_accprate
    )
)
Mean acceptance fraction: 0.728
In [24]:
samples_emcee = sampler.get_chain(flat=True)
samples_emcee.shape
Out[24]:
(1000, 2)
In [25]:
emcee_mean = np.array([np.mean(samples_emcee[:,0]),np.mean(samples_emcee[:,1])])
emcee_cov = np.cov(samples_emcee.T)
samples_emcee.shape, emcee_mean, emcee_cov
Out[25]:
((1000, 2),
 array([-6.58492212e-05,  1.04466032e-02]),
 array([[0.08967083, 0.0057612 ],
        [0.0057612 , 0.09591614]]))
In [26]:
plot_samples(torch.tensor(samples_emcee), f'Hamiltorch HMC Samples, Acceptance rate = {emcee_accprate:.3f}', lines = True)